From 16ca7acca881e167a3d42c903e6f3872a72c9f6a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 4 Aug 2015 16:14:40 -0700 Subject: [PATCH 01/39] It compiles!! --- project/SparkBuild.scala | 2 +- .../spark/sql/test/TestSQLContext.scala | 56 ---- .../spark/sql/JavaApplySchemaSuite.java | 7 +- .../apache/spark/sql/JavaDataFrameSuite.java | 31 ++- .../org/apache/spark/sql/JavaUDFSuite.java | 7 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 7 +- .../apache/spark/sql/CachedTableSuite.scala | 12 +- .../spark/sql/ColumnExpressionSuite.scala | 12 +- .../spark/sql/DataFrameAggregateSuite.scala | 10 +- .../spark/sql/DataFrameFunctionsSuite.scala | 8 +- .../spark/sql/DataFrameImplicitsSuite.scala | 5 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 8 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 5 +- .../apache/spark/sql/DataFrameStatSuite.scala | 10 +- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +- .../spark/sql/DataFrameTungstenSuite.scala | 10 +- .../apache/spark/sql/DateFunctionsSuite.scala | 6 +- .../org/apache/spark/sql/JoinSuite.scala | 12 +- .../apache/spark/sql/ListTablesSuite.scala | 6 +- .../spark/sql/MathExpressionsSuite.scala | 6 +- .../scala/org/apache/spark/sql/RowSuite.scala | 6 +- .../org/apache/spark/sql/SQLConfSuite.scala | 6 +- .../apache/spark/sql/SQLContextSuite.scala | 9 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +- .../sql/ScalaReflectionRelationSuite.scala | 6 +- .../apache/spark/sql/SerializationSuite.scala | 9 +- .../spark/sql/StringFunctionsSuite.scala | 6 +- .../scala/org/apache/spark/sql/TestData.scala | 197 -------------- .../scala/org/apache/spark/sql/UDFSuite.scala | 13 +- .../spark/sql/UserDefinedTypeSuite.scala | 6 +- .../columnar/InMemoryColumnarQuerySuite.scala | 13 +- .../columnar/PartitionBatchPruningSuite.scala | 11 +- .../spark/sql/execution/AggregateSuite.scala | 14 +- .../spark/sql/execution/PlannerSuite.scala | 19 +- .../execution/RowFormatConvertersSuite.scala | 16 +- .../spark/sql/execution/SparkPlanTest.scala | 15 +- .../execution/SparkSqlSerializer2Suite.scala | 15 +- .../sql/execution/TungstenSortSuite.scala | 16 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 10 +- .../UnsafeKVExternalSorterSuite.scala | 10 +- .../sql/execution/debug/DebuggingSuite.scala | 11 +- .../sql/execution/joins/SemiJoinSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 11 +- .../org/apache/spark/sql/json/JsonSuite.scala | 20 +- .../apache/spark/sql/json/TestJsonData.scala | 2 + .../ParquetAvroCompatibilitySuite.scala | 6 +- .../parquet/ParquetCompatibilityTest.scala | 5 +- .../sql/parquet/ParquetFilterSuite.scala | 4 +- .../spark/sql/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 8 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 4 +- .../sql/parquet/ParquetSchemaSuite.scala | 2 - .../spark/sql/parquet/ParquetTest.scala | 10 +- .../ParquetThriftCompatibilitySuite.scala | 5 +- .../spark/sql/sources/DataSourceTest.scala | 6 +- .../apache/spark/sql/test/MyTestData.scala | 253 ++++++++++++++++++ .../spark/sql/test/MyTestSQLContext.scala | 118 ++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 24 +- .../apache/spark/sql/hive/test/TestHive.scala | 16 +- .../spark/sql/hive/HiveParquetSuite.scala | 8 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 21 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 17 +- .../execution/AggregationQuerySuite.scala | 18 +- .../sql/hive/execution/SQLQuerySuite.scala | 36 +-- .../hive/orc/OrcHadoopFsRelationSuite.scala | 7 
+- .../spark/sql/hive/orc/OrcQuerySuite.scala | 8 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 16 +- .../apache/spark/sql/hive/parquetSuites.scala | 27 +- .../CommitFailureTestRelationSuite.scala | 17 +- .../ParquetHadoopFsRelationSuite.scala | 7 +- .../SimpleTextHadoopFsRelationSuite.scala | 5 +- .../sql/sources/hadoopFsRelationSuites.scala | 17 +- 73 files changed, 776 insertions(+), 600 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce..ebbcd9a48243 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -312,6 +312,7 @@ object OldDeps { ) } +// TODO: check if this is OK object SQL { lazy val settings = Seq( initialCommands in console := @@ -325,7 +326,6 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala deleted file mode 100644 index b3a4231da91c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - } - } - - /** - * Turn a logical plan into a [[DataFrame]]. 
This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } - -} - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index cb84e78d628c..478210d72d02 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +35,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,8 +48,9 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 2c669bb59a0b..f9444f60a08d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,26 +17,24 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConversions; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { @@ -46,9 +44,10 @@ public class JavaDataFrameSuite { @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + // TODO: restore the test data here somehow: TestData$.MODULE$.testData(); + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new SQLContext(sc); } @After diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff37..98c8a4aca6ca 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -23,12 +23,12 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,8 +40,9 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28a..90f802a01ef5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -25,9 +25,9 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -52,8 +52,9 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index eb3e91332206..fb012e2ec3b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,18 +23,16 @@ import scala.language.{implicitConversions, postfixOps} import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.MyTestSQLContext -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. 
- - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class CachedTableSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ - import ctx.sql + import ctx._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 35ca0b4c7cc2..5f6581b23a44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,15 +22,15 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class ColumnExpressionSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ - override def sqlContext(): SQLContext = ctx + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff7440a76..1b7d4f60c9f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class DataFrameAggregateSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ test("groupBy") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 431dcf7382f1..cf8058b17fd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class DataFrameFunctionsSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a494..9b598112732a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.MyTestSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class DataFrameImplicitsSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ test("RDD of tuples") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c706242d..2a628d2b1345 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.MyTestSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class DataFrameJoinSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44ee2c7..f0ae14935422 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.MyTestSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class DataFrameNaFunctionsSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ def createDF(): DataFrame = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 07a675e64f52..9096af0251b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -22,11 +22,11 @@ import java.util.Random import org.scalatest.Matchers._ import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.MyTestSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with MyTestSQLContext { + private val ctx 
= sqlContext + import ctx.implicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString @@ -132,7 +132,7 @@ class DataFrameStatSuite extends QueryTest { } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index aef940a52667..77493997296d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,13 +28,15 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils, MyTestSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class DataFrameSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContextWithData + import ctx.implicits._ + import ctx._ - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ + // For SQLTestUtils + protected override def _sqlContext = ctx test("analysis error should be eagerly reported") { // Eager analysis. @@ -298,7 +300,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("udf") { - val foo = udf((a: Int, b: String) => a.toString + b) + val foo = org.apache.spark.sql.functions.udf((a: Int, b: String) => a.toString + b) checkAnswer( // SELECT *, foo(key, value) FROM testData diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a97bc6..c2e516aefe19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ /** @@ -27,10 +27,12 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. 
*/ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContext + import ctx.implicits._ - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ + // For SQLTestUtils + protected override def _sqlContext = ctx test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897caf952a..21a3515ecb08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,11 +22,11 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - +class DateFunctionsSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ test("function current_date") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d896603..a15a2a5a6e0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -19,19 +19,15 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.MyTestSQLContext -class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. 
- TestData - - lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class JoinSuite extends QueryTest with BeforeAndAfterEach with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + import ctx._ test("equi-join is hash-join") { val x = testData2.as("x") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf..4b366482153f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class ListTablesSuite extends QueryTest with BeforeAndAfter with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5957d8..bf2525e1568a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.MyTestSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with MyTestSQLContext { import MathExpressionsTestData._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + private val ctx = sqlContext import ctx.implicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7865d6..be0ad1d67fdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class RowSuite extends SparkFunSuite with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ test("create row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c2..eca293ad20d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 
+17,11 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.MyTestSQLContext -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext private val testKey = "test.key.0" private val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a4..ee77e7b72e27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.MyTestSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with MyTestSQLContext { + private val ctx = sqlContext override def afterAll(): Unit = { SQLContext.setLastInstantiatedContext(ctx) + super.afterAll() } test("getOrCreate instantiates SQLContext") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 29dfcf257522..d0b229219445 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect @@ -28,20 +26,19 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. 
- TestData +class SQLQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContextWithData + import ctx.implicits._ + import ctx._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -179,7 +176,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd96d27..9a374f52661d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.MyTestSQLContext case class ReflectData( stringField: String, @@ -71,9 +72,8 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class ScalaReflectionRelationSuite extends SparkFunSuite with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ test("query case class RDD") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b79..31abd7835268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.MyTestSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with MyTestSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ab5da6ee79f1..b6a1b6c67917 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types.Decimal -class 
StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class StringFunctionsSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ test("string concat") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index bd9729c431f3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - 
TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) - :: 
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 183dc3407b3a..7d379d8a5972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class UDFSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ - override def sqlContext(): SQLContext = ctx + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f29935224e5b..f060851a363a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -66,9 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class UserDefinedTypeSuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContext import ctx.implicits._ private lazy val pointsRDD = Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 037e2048a863..6ee417d69fe4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,18 +19,15 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, TestData} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. 
- TestData - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class InMemoryColumnarQuerySuite extends QueryTest with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + import ctx._ test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2c0879927a12..7c3754f84595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,20 +17,22 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.MyTestSQLContext -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { + private val ctx = sqlContextWithData import ctx.implicits._ + import ctx._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) @@ -49,6 +51,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi override protected def afterAll(): Unit = { ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + super.afterAll() } before { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala index 20def6bef0c1..bd022ec26111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext class AggregateSuite extends SparkPlanTest { + private val ctx = sqlContext test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { - val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) - val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) + val codegenDefault = ctx.getConf(SQLConf.CODEGEN_ENABLED) + val unsafeDefault = ctx.getConf(SQLConf.UNSAFE_ENABLED) try { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) + ctx.setConf(SQLConf.CODEGEN_ENABLED, true) + ctx.setConf(SQLConf.UNSAFE_ENABLED, true) val df = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( df, @@ -41,8 +41,8 @@ class AggregateSuite extends SparkPlanTest { Seq.empty ) } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) + 
ctx.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + ctx.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 18b0e54dc7c5..2c7b93ccbf82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,22 +18,23 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} +import org.apache.spark.sql.{execution, Row, SQLConf, SQLContext} -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { + private val ctx = sqlContextWithData + import ctx.implicits._ + import ctx.planner._ + import ctx._ - override def sqlContext: SQLContext = TestSQLContext + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) @@ -83,7 +84,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) + val rowRDD = ctx.sparkContext.parallelize(row :: Nil) createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 707cd9c6d939..dc2c7b6fa185 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} -import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { + private val ctx = sqlContext private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c @@ -35,20 +35,20 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = 
Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -63,26 +63,26 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - TestSQLContext.createDataFrame(input), + ctx.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f46855edfe0d..4419f09e6cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,23 +17,22 @@ package org.apache.spark.sql.execution +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} +import org.apache.spark.sql.test.MyTestSQLContext -import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - protected def sqlContext: SQLContext = TestSQLContext +abstract class SparkPlanTest extends SparkFunSuite with MyTestSQLContext { /** * Creates a DataFrame from a local Seq of Product. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 7978ed57a937..f55503ba1d74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,14 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer import org.apache.spark.{ShuffleDependency, SparkFunSuite} +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with MyTestSQLContext { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = @@ -68,15 +65,15 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { checkSupported(new MyDenseVectorUDT, isSupported = false) } -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { +abstract class SparkSqlSerializer2Suite extends QueryTest with MyTestSQLContext { + protected val ctx = sqlContext + var allColumns: String = _ val serializerClass: Class[Serializer] = classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] var numShufflePartitions: Int = _ var useSerializer2: Boolean = _ - protected lazy val ctx = TestSQLContext - override def beforeAll(): Unit = { numShufflePartitions = ctx.conf.numShufflePartitions useSerializer2 = ctx.conf.useSqlSerializer2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 88bce0e319f9..8ce75e8f2318 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -19,25 +19,23 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ /** * A test suite that generates randomized data to test the [[TungstenSort]] operator. 
*/ -class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest { + private val ctx = sqlContext override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } test("sort followed by limit") { @@ -61,7 +59,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { } test("sorting updates peak execution memory") { - val sc = TestSQLContext.sparkContext + val sc = ctx.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), @@ -80,8 +78,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = ctx.createDataFrame( + ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index ef827b0fe9b5..9f1101fe6aa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String * * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. */ -class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with MyTestSQLContext { import UnsafeFixedWidthAggregationMap._ @@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - // Memory consumption in the beginning of the task. 
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 601a5a07ad00..bb4ab9f1e986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,16 +22,15 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. */ -class UnsafeKVExternalSorterSuite extends SparkFunSuite { - +class UnsafeKVExternalSorterSuite extends SparkFunSuite with MyTestSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) @@ -65,9 +64,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]") test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") { - // Calling this make sure we have block manager and everything else setup. 
- TestSQLContext - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) val shuffleMemMgr = new TestShuffleMemoryManager TaskContext.setTaskContext(new TaskContextImpl( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8ec3985e0036..12ef64252a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.MyTestSQLContext + +class DebuggingSuite extends SparkFunSuite with MyTestSQLContext { + private val ctx = sqlContextWithData -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { - testData.debug() + ctx.testData.debug() } test("DataFrame.typeCheck()") { - testData.typeCheck() + ctx.testData.typeCheck() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3d..17720c1c5da3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} -class SemiJoinSuite extends SparkPlanTest{ +class SemiJoinSuite extends SparkPlanTest { val left = Seq( (1, 2.0), (1, 2.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f..6f1206230767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,10 +25,15 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ + val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -42,10 +47,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. 
We need these to test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 84b52ca2c733..073171743683 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,10 +24,15 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ + val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -37,10 +42,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index f19f22fca7d5..7405cef320df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,19 +23,20 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.test.MyTestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with TestJsonData { +class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { + private val _ctx = sqlContextWithData + import _ctx.implicits._ + import _ctx._ - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.sql - import ctx.implicits._ + protected override def ctx: SQLContext = _ctx test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -1040,24 +1041,23 @@ class JsonSuite extends QueryTest with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext val relation1 = new JSONRelation( "path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + ctx) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( "path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + ctx) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( "path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)), - context) + ctx) val logicalRelation3 = LogicalRelation(relation3) assert(relation1 === relation2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index eb62066ac643..182063097e32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +// TODO: clean me up + trait TestJsonData { protected def ctx: SQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala index bfa427349ff6..6d7a923097dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala @@ -26,17 +26,13 @@ import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - override protected def beforeAll(): Unit = { super.beforeAll() - val writer = new AvroParquetWriter[ParquetAvroCompat]( new Path(parquetStore.getCanonicalPath), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala index 57478931cd50..1c8a1c4def24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -23,12 +23,11 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { +abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { protected var parquetStore: File = _ /** @@ -40,12 +39,14 @@ abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with protected def stagingDir: Option[String] = None override protected def beforeAll(): Unit = { + super.beforeAll() parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") parquetStore.delete() } override protected def afterAll(): Unit = { Utils.deleteRecursively(parquetStore) + super.afterAll() } def readParquetSchema(path: String): MessageType = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index b6a7c4fbddbd..997779b4985f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * data type is nullable. 
*/ class ParquetFilterSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + private val ctx = sqlContext private def checkFilterPredicate( df: DataFrame, @@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import ctx.implicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index b415da5b8c13..7035bde00d9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -63,8 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ + private val ctx = sqlContext + import ctx.implicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11..9466fd51244a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpe import org.apache.spark.sql.types._ import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String -import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -41,10 +40,11 @@ case class ParquetData(intField: Int, stringField: String) case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { + import PartitioningUtils._ - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index a95f70f2bba6..6ecea93ee3fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.util.Utils * A test suite that tests various Parquet queries. 
*/ class ParquetQuerySuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql + private val ctx = sqlContext + import ctx._ test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419..b5c88743c870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,11 +24,9 @@ import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 64e94056f209..9889c40236c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.{SQLContext, DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +33,11 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => +private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = sqlContext + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. 
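For suites that need the SQLTestUtils helpers but are not Parquet tests, a sketch of the same wiring (the suite name and test body are illustrative, not from the patch): mix in both traits and point the protected _sqlContext hook at the context supplied by MyTestSQLContext.

    package org.apache.spark.sql.test

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.{SQLConf, SQLContext}

    // Hypothetical suite combining the shared context with the SQLTestUtils helpers.
    class MyUtilsBackedSuite extends SparkFunSuite with SQLTestUtils with MyTestSQLContext {

      // For SQLTestUtils: route its helpers to the shared context.
      protected override def _sqlContext: SQLContext = sqlContext

      test("withSQLConf applies and then restores a config value") {
        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
          // Inside the block the override is visible; afterwards the old value is restored.
          assert(sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key) === "true")
        }
      }
    }
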
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala index 1c532d78790d..7a75a133e5b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580..fa253105bdb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.sources import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.MyTestSQLContext -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { +abstract class DataSourceTest extends QueryTest with BeforeAndAfter with MyTestSQLContext { // We want to test some edge cases. protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala new file mode 100644 index 000000000000..30ffb07ffb4a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + +/** + * A collection of sample data used in SQL tests. 
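+ *
+ * A usage sketch (mirroring how DebuggingSuite in this patch reaches the data): a suite that
+ * mixes in [[MyTestSQLContext]] grabs the data-backed context and touches a table lazily:
+ * {{{
+ *   private val ctx = sqlContextWithData
+ *   ctx.testData                              // first access registers the "testData" temp table
+ *   ctx.sql("SELECT count(*) FROM testData")
+ * }}}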
+ */ +private[spark] trait MyTestData { + protected val sqlContext: SQLContext + import sqlContext.implicits._ + + // All test data should be lazy because the SQLContext is not set up yet + + lazy val testData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + lazy val testData2: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + lazy val testData3: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + lazy val negativeData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + lazy val largeAndSmallInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + lazy val decimalData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + lazy val binaryData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + lazy val upperCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + lazy val lowerCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + lazy val arrayData: RDD[ArrayData] = { + val rdd = sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + lazy val mapData: RDD[MapData] = { + val rdd = sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + lazy val repeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + 
List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + lazy val nullInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + lazy val allNulls: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + lazy val nullStrings: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + lazy val tableName: DataFrame = { + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + lazy val unparsedStrings: RDD[String] = { + sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + lazy val withEmptyParts: RDD[IntField] = { + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + lazy val person: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + lazy val salary: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + lazy val complexData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /* ------------------------------ * + | Case classes used in test data | + * ------------------------------ */ + + private[spark] case class TestData(key: Int, value: String) + private[spark] case class TestData2(a: Int, b: Int) + private[spark] case class TestData3(a: Int, b: Option[Int]) + private[spark] case class LargeAndSmallInts(a: Int, b: Int) + private[spark] case class DecimalData(a: BigDecimal, b: BigDecimal) + private[spark] case class BinaryData(a: Array[Byte], b: Int) + private[spark] case class UpperCaseData(N: Int, L: String) + private[spark] case class LowerCaseData(n: Int, l: String) + private[spark] case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + private[spark] case class MapData(data: scala.collection.Map[Int, String]) + private[spark] case class StringData(s: String) + private[spark] case class IntField(i: Int) + private[spark] case class NullInts(a: Integer) + private[spark] case class NullStrings(n: Int, s: String) + private[spark] case class TableName(tableName: String) + private[spark] case class Person(id: Int, name: String, age: Int) + private[spark] case class Salary(personId: Int, salary: Double) + private[spark] case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala new file mode 100644 index 000000000000..23f373596623 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A SQLContext that can be used for local testing. + */ +private[spark] class MyLocalSQLContext(sc: SparkContext) extends SQLContext(sc) with MyTestData { + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) + } + + // For test data + protected override val sqlContext: SQLContext = this + + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } + + protected[sql] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + /** Fewer partitions to speed up testing. */ + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + } + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to + * construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(this, plan) + } +} + +/** + * A scalatest trait for test suites where all tests share a single [[SQLContext]]. + */ +private[spark] trait MyTestSQLContext extends SparkFunSuite with BeforeAndAfterAll { + + /** + * The [[SQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[SparkContext]] will be run in local mode with the default + * test configurations. + */ + private var _ctx: SQLContext = new MyLocalSQLContext + + /** The [[SQLContext]] to use for all tests in this suite. */ + protected def sqlContext: SQLContext = _ctx + + /** + * The [[MyLocalSQLContext]] to use for all tests in this suite. + * This one comes with all the data prepared in advance. + */ + protected def sqlContextWithData: MyLocalSQLContext = { + _ctx match { + case local: MyLocalSQLContext => local + case _ => fail("this SQLContext does not have data prepared in advance") + } + } + + /** + * Switch the [[SQLContext]] with the one provided. + * + * This stops the underlying [[SparkContext]] and expects a new one to be created. + * This is needed because only one [[SparkContext]] is allowed per JVM. 
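+ *
+ * A usage sketch (this mirrors how the Hive-side suites in this patch switch contexts):
+ * {{{
+ *   // Use a hive context instead of the default local one.
+ *   switchSQLContext(() => new TestHiveContext)
+ * }}}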
+ */ + protected def switchSQLContext(newContext: () => SQLContext): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = newContext() + } + } + + /** + * Execute the given block of code with a custom [[SQLContext]]. + * At the end of the method, a [[MyLocalSQLContext]] will be restored. + */ + protected def withSQLContext[T](newContext: () => SQLContext)(body: => T) { + switchSQLContext(newContext) + try { + body + } finally { + switchSQLContext(() => new MyLocalSQLContext) + } + } + + protected override def afterAll(): Unit = { + super.afterAll() + switchSQLContext(() => null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec..f141c15b99f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext + protected def _sqlContext: SQLContext - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + protected def configuration = _sqlContext.sparkContext.hadoopConfiguration /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +39,12 @@ trait SQLTestUtils { this: SparkFunSuite => */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(_sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) + case (key, None) => _sqlContext.conf.unsetConf(key) } } } @@ -76,7 +76,7 @@ trait SQLTestUtils { this: SparkFunSuite => * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(_sqlContext.dropTempTable) } /** @@ -85,7 +85,7 @@ trait SQLTestUtils { this: SparkFunSuite => protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + _sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -98,12 +98,12 @@ trait SQLTestUtils { this: SparkFunSuite => val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + _sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -111,7 +111,7 @@ trait SQLTestUtils { this: SparkFunSuite => * `f` returns. 
*/ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + _sqlContext.sql(s"USE $db") + try f finally _sqlContext.sql(s"USE default") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 8d0bf46e8fad..c659f4e8eb07 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -42,7 +42,7 @@ import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ import scala.collection.JavaConversions._ -// SPARK-3729: Test key required to check for initialization errors with config. +// TODO: remove it object TestHive extends TestHiveContext( new SparkContext( @@ -72,6 +72,20 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { import HiveContext._ + def this() { + this(new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + // SPARK-3729: Test key required to check for initialization errors with config. + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) + } + // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. System.clearProperty("spark.hostPort") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278..59464a048971 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.parquet.ParquetTest import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - import sqlContext._ + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + import ctx._ test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 4fdf774ead75..363cacfe7129 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,26 +23,33 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException -import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import 
org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll +class MetastoreDataSourcesSuite + extends QueryTest + with SQLTestUtils + with MyTestSQLContext with Logging { - override val sqlContext = TestHive + + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext.asInstanceOf[TestHiveContext] + import ctx.implicits._ + import ctx._ + + protected override def _sqlContext: SQLContext = ctx var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20..89a97ca54144 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.{QueryTest, SaveMode, SQLContext} -class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive +class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { - import sqlContext.sql + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + import ctx.sql + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx private val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6f0db27775e4..c003e827bb16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -18,17 +18,21 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf, SQLContext} import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { - override val sqlContext = TestHive - import sqlContext.implicits._ + // Use a hive context instead + switchSQLContext(() 
=> new TestHiveContext) + private val ctx = sqlContext + import ctx.implicits._ + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx var originalUseAggregate2: Boolean = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 95c1da6e9796..e03c743ecc12 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,12 +26,10 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) @@ -64,11 +62,19 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive +class SQLQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { + + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext.asInstanceOf[TestHiveContext] + import ctx.implicits._ + import ctx._ + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx test("UDTF") { - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + sql(s"ADD JAR ${getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( @@ -612,7 +618,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val rowRdd = sparkContext.parallelize(row :: Nil) - TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") + createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -1030,10 +1036,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. 
- TestHive.sql( - s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + ctx.sql(s"ADD JAR ${getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - TestHive.sql( + ctx.sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -1082,22 +1087,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { - val df = - TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + val df = ctx.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( - TestHive.sql( + ctx.sql( """ |select id, concat(year(datef)) |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) - TestHive.dropTempTable("test_SPARK8588") + ctx.dropTempTable("test_SPARK8588") } test("SPARK-9371: fix the support for special chars in column names for hive context") { - TestHive.read.json(TestHive.sparkContext.makeRDD( + ctx.read.json(ctx.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index deec0048d24b..3677aed98182 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ - import sqlContext._ - import sqlContext.implicits._ + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 744d46293814..6d970b1ae719 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,11 +21,8 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -48,7 +45,10 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { +class OrcQuerySuite extends QueryTest with OrcTest { + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da0..9c37cc652b7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -24,13 +24,19 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive +private[sql] trait OrcTest extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { - import sqlContext.implicits._ - import sqlContext.sparkContext + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + import ctx.implicits._ + import ctx.sparkContext + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c4bc60086f6e..7bbfd3e98fad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,17 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -58,6 +54,9 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { + private val ctx = sqlContext.asInstanceOf[TestHiveContext] + import ctx._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -536,6 +535,10 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { + private val ctx = sqlContext.asInstanceOf[TestHiveContext] + import ctx.implicits._ + import ctx._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -684,8 +687,16 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. 
*/ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with MyTestSQLContext { + + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706..c5b3028e95df 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -20,12 +20,19 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { + + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName @@ -35,7 +42,7 @@ class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { // Here we coalesce partition number to 1 to ensure that only a single task is issued. This // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` // directory while committing/aborting the job. See SPARK-8513 for more details. 
- val df = sqlContext.range(0, 10).coalesce(1) + val df = ctx.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d..c272fe6b66d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ - import sqlContext._ - import sqlContext.implicits._ + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index e8975e5f5cd0..68c5d028a70f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -23,9 +23,10 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + private val ctx = sqlContext + import ctx._ - import sqlContext._ + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index dd274023a1cf..4d873e75e7df 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,16 +28,21 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyTestSQLContext { - import sqlContext.sql - import sqlContext.implicits._ + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext + import ctx.implicits._ + import ctx.sql + + // For SQLTestUtils + protected override def _sqlContext: SQLContext = ctx val dataSourceName: 
String From 5e6a20c37562fcc792c56fcdf8e9df16a661daf1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 4 Aug 2015 18:17:44 -0700 Subject: [PATCH 02/39] Refactor SQLTestUtils to reduce duplication --- .../spark/sql/ColumnExpressionSuite.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 7 +- .../spark/sql/DataFrameTungstenSuite.scala | 7 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 7 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 7 +- .../spark/sql/execution/PlannerSuite.scala | 9 +- .../spark/sql/parquet/ParquetTest.scala | 9 +- .../spark/sql/test/MyTestSQLContext.scala | 4 +- .../apache/spark/sql/test/SQLTestUtils.scala | 10 +- .../spark/sql/hive/test/HiveTestUtils.scala | 32 ++++++ .../sql/hive/test/MyTestHiveContext.scala | 74 ++++++++++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 22 ++-- .../spark/sql/hive/MultiDatabaseSuite.scala | 75 ++++++------- .../execution/AggregationQuerySuite.scala | 105 ++++++++---------- .../sql/hive/execution/SQLQuerySuite.scala | 13 +-- .../hive/orc/OrcHadoopFsRelationSuite.scala | 2 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 14 +-- .../apache/spark/sql/hive/orc/OrcTest.scala | 17 +-- .../apache/spark/sql/hive/parquetSuites.scala | 3 - .../CommitFailureTestRelationSuite.scala | 15 +-- .../ParquetHadoopFsRelationSuite.scala | 6 +- .../SimpleTextHadoopFsRelationSuite.scala | 2 +- .../sql/sources/hadoopFsRelationSuites.scala | 59 +++++----- 23 files changed, 273 insertions(+), 233 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5f6581b23a44..17eacd18dd8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,16 +22,13 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { private val ctx = sqlContextWithData import ctx.implicits._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 77493997296d..954e66c29fc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,16 +28,13 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} -class DataFrameSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { 
+class DataFrameSuite extends QueryTest with SQLTestUtils { private val ctx = sqlContextWithData import ctx.implicits._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext = ctx - test("analysis error should be eagerly reported") { // Eager analysis. withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index c2e516aefe19..3ae70c03e392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** @@ -27,13 +27,10 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. */ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { private val ctx = sqlContext import ctx.implicits._ - // For SQLTestUtils - protected override def _sqlContext = ctx - test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d0b229219445..ed098e258a28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,20 +26,17 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { +class SQLQuerySuite extends QueryTest with SQLTestUtils { private val ctx = sqlContextWithData import ctx.implicits._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7d379d8a5972..de395d4e1670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.SQLTestUtils private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { +class UDFSuite extends QueryTest with SQLTestUtils { private val ctx = sqlContextWithData import ctx.implicits._ 
import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2c7b93ccbf82..5f9828c1c068 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,20 +22,17 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{execution, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{execution, Row, SQLConf} -class PlannerSuite extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { +class PlannerSuite extends SparkFunSuite with SQLTestUtils { private val ctx = sqlContextWithData import ctx.implicits._ import ctx.planner._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 9889c40236c9..e30e2f503e54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -import org.apache.spark.sql.{SQLContext, DataFrame, SaveMode} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,10 +33,7 @@ import org.apache.spark.sql.{SQLContext, DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { - - // For SQLTestUtils - protected override def _sqlContext: SQLContext = sqlContext +private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils { /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala index 23f373596623..1a2dcba81648 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala @@ -86,7 +86,7 @@ private[spark] trait MyTestSQLContext extends SparkFunSuite with BeforeAndAfterA } /** - * Switch the [[SQLContext]] with the one provided. 
+ * Switch to the provided [[SQLContext]]. * * This stops the underlying [[SparkContext]] and expects a new one to be created. * This is needed because only one [[SparkContext]] is allowed per JVM. @@ -112,7 +112,7 @@ private[spark] trait MyTestSQLContext extends SparkFunSuite with BeforeAndAfterA } protected override def afterAll(): Unit = { - super.afterAll() switchSQLContext(() => null) + super.afterAll() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index f141c15b99f7..5d3af085bd60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -26,7 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils -trait SQLTestUtils { this: SparkFunSuite => +private[spark] trait SQLTestUtils + extends SparkFunSuite + with AbstractSQLTestUtils + with MyTestSQLContext { + + protected final override def _sqlContext = sqlContext +} + +private[spark] trait AbstractSQLTestUtils { this: SparkFunSuite => protected def _sqlContext: SQLContext protected def configuration = _sqlContext.sparkContext.hadoopConfiguration diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala new file mode 100644 index 000000000000..266282e228e6 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.AbstractSQLTestUtils + +/** + * This is analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. + */ +private[spark] trait HiveTestUtils + extends SparkFunSuite + with AbstractSQLTestUtils + with MyTestHiveContext { + + protected final override def _sqlContext = hiveContext +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala new file mode 100644 index 000000000000..ef0fe2e2e27c --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + + +/** + * A scalatest trait for test suites where all tests share a single + * [[org.apache.spark.sql.hive.HiveContext]]. + */ +private[spark] trait MyTestHiveContext extends SparkFunSuite with BeforeAndAfterAll { + + /** + * The [[TestHiveContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local mode + * with the default test configurations. + */ + private var _ctx: TestHiveContext = new TestHiveContext + + /** The [[TestHiveContext]] to use for all tests in this suite. */ + protected def hiveContext: TestHiveContext = _ctx + + /** + * Switch to the provided [[org.apache.spark.sql.hive.HiveContext]]. + * + * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one to + * be created. This is needed because only one [[org.apache.spark.SparkContext]] is + * allowed per JVM. + */ + protected def switchHiveContext(newContext: () => TestHiveContext): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = newContext() + } + } + + /** + * Execute the given block of code with a custom [[TestHiveContext]]. + * At the end of the method, a [[TestHiveContext]] will be restored. 
+ */ + protected def withHiveContext[T](newContext: () => TestHiveContext)(body: => T) { + switchHiveContext(newContext) + try { + body + } finally { + switchHiveContext(() => new TestHiveContext) + } + } + + protected override def afterAll(): Unit = { + switchHiveContext(() => null) + super.afterAll() + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 363cacfe7129..5a1aa079c485 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -28,9 +28,8 @@ import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -39,18 +38,13 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest - with SQLTestUtils - with MyTestSQLContext + with HiveTestUtils with Logging { - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext.asInstanceOf[TestHiveContext] + private val ctx = hiveContext import ctx.implicits._ import ctx._ - protected override def _sqlContext: SQLContext = ctx - var jsonFilePath: String = _ override def beforeAll(): Unit = { @@ -842,17 +836,17 @@ class MetastoreDataSourcesSuite test("SPARK-8156:create table to specific database by 'use dbname' ") { val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - sqlContext.sql("""create database if not exists testdb8156""") - sqlContext.sql("""use testdb8156""") + ctx.sql("""create database if not exists testdb8156""") + ctx.sql("""use testdb8156""") df.write .format("parquet") .mode(SaveMode.Overwrite) .saveAsTable("ttt3") checkAnswer( - sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + ctx.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), Row("ttt3", false)) - sqlContext.sql("""use default""") - sqlContext.sql("""drop database if exists testdb8156 CASCADE""") + ctx.sql("""use default""") + ctx.sql("""drop database if exists testdb8156 CASCADE""") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 89a97ca54144..6321c12777d3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,40 +17,33 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} -import org.apache.spark.sql.{QueryTest, SaveMode, SQLContext} +import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.{QueryTest, SaveMode} -class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +class MultiDatabaseSuite extends QueryTest with HiveTestUtils { + private 
val ctx = hiveContext import ctx.sql - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - - private val df = sqlContext.range(10).coalesce(1) + private val df = ctx.range(10).coalesce(1) test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df) + assert(ctx.tableNames().contains("t")) + checkAnswer(ctx.table("t"), df) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(ctx.tableNames(db).contains("t")) + checkAnswer(ctx.table(s"$db.t"), df) } } test(s"saveAsTable() to non-default database - without USE - Overwrite") { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(ctx.tableNames(db).contains("t")) + checkAnswer(ctx.table(s"$db.t"), df) } } @@ -59,12 +52,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df.unionAll(df)) + assert(ctx.tableNames().contains("t")) + checkAnswer(ctx.table("t"), df.unionAll(df)) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + assert(ctx.tableNames(db).contains("t")) + checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) } } @@ -72,8 +65,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + assert(ctx.tableNames(db).contains("t")) + checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) } } @@ -81,10 +74,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(ctx.tableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) } } } @@ -93,13 +86,13 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(ctx.tableNames().contains("t")) } - assert(sqlContext.tableNames(db).contains("t")) + assert(ctx.tableNames(db).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) } } @@ -107,10 +100,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte withTempDatabase { db => activateDatabase(db) { sql("CREATE TABLE t (key INT)") - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(ctx.table("t"), ctx.emptyDataFrame) } - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(ctx.table(s"$db.t"), ctx.emptyDataFrame) } } @@ -118,21 +111,21 @@ class 
MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte withTempDatabase { db => activateDatabase(db) { sql(s"CREATE TABLE t (key INT)") - assert(sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(ctx.tableNames().contains("t")) + assert(!ctx.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(sqlContext.tableNames(db).contains("t")) + assert(!ctx.tableNames().contains("t")) + assert(ctx.tableNames(db).contains("t")) activateDatabase(db) { sql(s"DROP TABLE t") - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(!ctx.tableNames().contains("t")) + assert(!ctx.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames(db).contains("t")) + assert(!ctx.tableNames().contains("t")) + assert(!ctx.tableNames(db).contains("t")) } } @@ -151,12 +144,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with MyTestSQLConte |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(ctx.table("t"), ctx.emptyDataFrame) df.write.parquet(s"$path/p=1") sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") - checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + checkAnswer(ctx.table("t"), df.withColumn("p", lit(1))) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index c003e827bb16..b13c0cb82a23 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -18,27 +18,20 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { + protected final val ctx = hiveContext import ctx.implicits._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - var originalUseAggregate2: Boolean = _ override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") + originalUseAggregate2 = ctx.conf.useSqlAggregate2 + ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -71,27 +64,27 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") - val emptyDF = sqlContext.createDataFrame( - sqlContext.sparkContext.emptyRDD[Row], + val emptyDF = 
ctx.createDataFrame( + ctx.sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") // Register UDAFs - sqlContext.udaf.register("mydoublesum", new MyDoubleSum) - sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + ctx.udaf.register("mydoublesum", new MyDoubleSum) + ctx.udaf.register("mydoubleavg", new MyDoubleAvg) } override def afterAll(): Unit = { - sqlContext.sql("DROP TABLE IF EXISTS agg1") - sqlContext.sql("DROP TABLE IF EXISTS agg2") - sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + ctx.sql("DROP TABLE IF EXISTS agg1") + ctx.sql("DROP TABLE IF EXISTS agg2") + ctx.dropTempTable("emptyTable") + ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | AVG(value), @@ -108,7 +101,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | AVG(value), @@ -127,7 +120,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My // If there is a GROUP BY clause and the table is empty, there is no output. checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | AVG(value), @@ -147,7 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("only do grouping") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT key |FROM agg1 @@ -156,7 +149,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT DISTINCT value1, key |FROM agg2 @@ -173,7 +166,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(null, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT value1, key |FROM agg2 @@ -193,7 +186,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("case in-sensitive resolution") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(value), kEY - 100 |FROM agg1 @@ -202,7 +195,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT sum(distinct value1), kEY - 100, count(distinct value1) |FROM agg2 @@ -211,7 +204,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT valUe * key - 100 |FROM agg1 @@ -227,7 +220,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("test average no key in output") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(value) |FROM agg1 @@ -238,7 +231,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("test average") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT key, avg(value) |FROM agg1 @@ -247,7 +240,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) 
checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(value), key |FROM agg1 @@ -256,7 +249,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(value) + 1.5, key + 10 |FROM agg1 @@ -265,14 +258,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT avg(null) """.stripMargin), @@ -281,7 +274,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("udaf") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | key, @@ -301,7 +294,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("non-AlgebraicAggregate aggreguate function") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT mydoublesum(value), key |FROM agg1 @@ -310,14 +303,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT mydoublesum(value) FROM agg1 """.stripMargin), Row(89.0) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT mydoublesum(null) """.stripMargin), @@ -326,7 +319,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT mydoublesum(value), key, avg(value) |FROM agg1 @@ -338,7 +331,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(30.0, null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | mydoublesum(value + 1.5 * key), @@ -358,7 +351,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("single distinct column set") { // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | min(distinct value1), @@ -371,7 +364,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | mydoubleavg(distinct value1), @@ -390,7 +383,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | key, @@ -410,7 +403,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("test count") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | count(value2), @@ -433,7 +426,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My Row(0, null, 1, 1, null) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | count(value2), @@ -460,7 +453,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My test("error handling") { withSQLConf("spark.sql.useAggregate2" -> "false") { val errorMessage = intercept[AnalysisException] { - sqlContext.sql( + ctx.sql( """ |SELECT | key, @@ -478,7 +471,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My // we can remove the following two tests. withSQLConf("spark.sql.useAggregate2" -> "true") { val errorMessage = intercept[AnalysisException] { - sqlContext.sql( + ctx.sql( """ |SELECT | key, @@ -491,7 +484,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with My assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( + val newAggregateOperators = ctx.sql( """ |SELECT | key, @@ -515,14 +508,14 @@ class SortBasedAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") + originalUnsafeEnabled = ctx.conf.unsafeEnabled + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "false") super.beforeAll() } override def afterAll(): Unit = { super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } @@ -531,13 +524,13 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + originalUnsafeEnabled = ctx.conf.unsafeEnabled + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") super.beforeAll() } override def afterAll(): Unit = { super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index e03c743ecc12..69d4335c762f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,10 +26,9 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import 
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) @@ -62,17 +61,11 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext.asInstanceOf[TestHiveContext] +class SQLQuerySuite extends QueryTest with HiveTestUtils { + private val ctx = hiveContext import ctx.implicits._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - test("UDTF") { sql(s"ADD JAR ${getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 3677aed98182..759d537041d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = sqlContext + private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 6d970b1ae719..415363793b83 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -46,7 +46,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with OrcTest { - private val ctx = sqlContext + private val ctx = hiveContext import ctx.implicits._ import ctx._ @@ -63,7 +63,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.orc(file), + ctx.read.orc(file), data.toDF().collect()) } } @@ -293,7 +293,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { withTable("empty_orc") { withTempTable("empty", "single") { - sqlContext.sql( + ctx.sql( s"""CREATE TABLE empty_orc(key INT, value STRING) |STORED AS ORC |LOCATION '$path' @@ -304,13 +304,13 @@ class OrcQuerySuite extends QueryTest with OrcTest { // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because // Spark SQL ORC data source always avoids write empty ORC files. 
- sqlContext.sql( + ctx.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM empty """.stripMargin) val errorMessage = intercept[AnalysisException] { - sqlContext.read.orc(path) + ctx.read.orc(path) }.getMessage assert(errorMessage.contains("Failed to discover schema from ORC files")) @@ -318,12 +318,12 @@ class OrcQuerySuite extends QueryTest with OrcTest { val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) singleRowDF.registerTempTable("single") - sqlContext.sql( + ctx.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.read.orc(path) + val df = ctx.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 9c37cc652b7c..2c946314f209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -24,20 +24,13 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.hive.test.HiveTestUtils -private[sql] trait OrcTest extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +private[sql] trait OrcTest extends SparkFunSuite with HiveTestUtils { + private val ctx = hiveContext import ctx.implicits._ import ctx.sparkContext - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` * returns. 
@@ -58,7 +51,7 @@ private[sql] trait OrcTest extends SparkFunSuite with SQLTestUtils with MyTestSQ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.orc(path))) + withOrcFile(data)(path => f(ctx.read.orc(path))) } /** @@ -70,7 +63,7 @@ private[sql] trait OrcTest extends SparkFunSuite with SQLTestUtils with MyTestSQ (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + ctx.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 7bbfd3e98fad..6084f594a2d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -695,9 +695,6 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with import ctx.implicits._ import ctx._ - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - var partitionedTableDir: File = null var normalTableDir: File = null var partitionedTableDirWithKey: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index c5b3028e95df..d7ebafc3b01e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -20,19 +20,10 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.hive.test.HiveTestUtils - -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext - - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx +class CommitFailureTestRelationSuite extends SparkFunSuite with HiveTestUtils { + private val ctx = hiveContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index c272fe6b66d9..830591ed54b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = sqlContext + private val ctx = hiveContext import ctx.implicits._ import ctx._ @@ -120,7 +120,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5) + val df = ctx.range(0, 5) df.write.mode(SaveMode.Overwrite).parquet(path) val summaryPath = new Path(path, "_metadata") @@ -131,7 +131,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { fs.delete(commonSummaryPath, true) df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + checkAnswer(ctx.read.parquet(path), df.unionAll(df)) assert(fs.exists(summaryPath)) assert(fs.exists(commonSummaryPath)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 68c5d028a70f..5c8beea3a975 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = sqlContext + private val ctx = hiveContext import ctx._ override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 4d873e75e7df..dee39d8f7888 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,22 +28,15 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +abstract class HadoopFsRelationTest extends QueryTest with HiveTestUtils { + private val ctx = hiveContext import ctx.implicits._ import ctx.sql - // For SQLTestUtils - protected override def _sqlContext: SQLContext = ctx - val 
dataSourceName: String val dataSchema = @@ -112,7 +105,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -126,7 +119,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -146,7 +139,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) assert(fs.listStatus(path).isEmpty) } } @@ -160,7 +153,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .save(file.getCanonicalPath) checkQueries( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -181,7 +174,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -203,7 +196,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -225,7 +218,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -261,7 +254,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), testDF.collect()) + checkAnswer(ctx.table("t"), testDF.collect()) } } @@ -270,7 +263,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(ctx.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -289,7 +282,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT withTempTable("t") { testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(sqlContext.table("t").collect().isEmpty) + assert(ctx.table("t").collect().isEmpty) } } @@ -300,7 +293,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - 
checkQueries(sqlContext.table("t")) + checkQueries(ctx.table("t")) } } @@ -320,7 +313,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + checkAnswer(ctx.table("t"), partitionedTestDF.collect()) } } @@ -340,7 +333,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(ctx.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } @@ -360,7 +353,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + checkAnswer(ctx.table("t"), partitionedTestDF.collect()) } } @@ -409,7 +402,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .partitionBy("p1", "p2") .saveAsTable("t") - assert(sqlContext.table("t").collect().isEmpty) + assert(ctx.table("t").collect().isEmpty) } } @@ -421,7 +414,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .partitionBy("p1", "p2") .save(file.getCanonicalPath) - val df = sqlContext.read + val df = ctx.read .format(dataSourceName) .option("dataSchema", dataSchema.json) .load(s"${file.getCanonicalPath}/p1=*/p2=???") @@ -433,7 +426,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT s"${file.getCanonicalFile}/p1=2/p2=bar" ).map { p => val path = new Path(p) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } @@ -461,7 +454,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTempTable("t") { - checkAnswer(sqlContext.table("t"), input.collect()) + checkAnswer(ctx.table("t"), input.collect()) } } } @@ -476,7 +469,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) + checkAnswer(ctx.table("t"), df.select('b, 'c, 'a).collect()) } } @@ -488,7 +481,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT test("SPARK-8406: Avoids name collision while writing files") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext + ctx .range(10000) .repartition(250) .write @@ -497,7 +490,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT .save(path) assertResult(10000) { - sqlContext + ctx .read .format(dataSourceName) .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) @@ -510,7 +503,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT test("SPARK-8578 specified custom output committer will not be used to append data") { val clonedConf = new Configuration(configuration) try { - val df = sqlContext.range(1, 10).toDF("i") + val df = ctx.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) configuration.set( @@ -525,7 +518,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with MyT // with file format and AlwaysFailOutputCommitter will not be used. 
df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) checkAnswer( - sqlContext.read + ctx.read .format(dataSourceName) .option("dataSchema", df.schema.json) .load(dir.getCanonicalPath), From d1d1449f3c01c8bb08cc56e9169180313f187bf3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 10:30:31 -0700 Subject: [PATCH 03/39] Remove HiveTest singleton This allows us to use custom SparkContexts in hive tests. --- project/SparkBuild.scala | 5 +- sql/README.md | 1 - .../HashJoinCompatibilitySuite.scala | 5 +- .../execution/HiveCompatibilitySuite.scala | 20 +++--- .../HiveWindowFunctionQuerySuite.scala | 20 +++--- .../{TestHive.scala => TestHiveContext.scala} | 18 ------ .../spark/sql/hive/JavaDataFrameSuite.java | 4 +- .../hive/JavaMetastoreDataSourcesSuite.java | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 21 ++++--- .../spark/sql/hive/ErrorPositionSuite.scala | 8 ++- .../hive/HiveDataFrameAnalyticsSuite.scala | 15 ++--- .../sql/hive/HiveDataFrameJoinSuite.scala | 6 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 8 ++- .../sql/hive/HiveMetastoreCatalogSuite.scala | 10 +-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 22 +++---- .../sql/hive/InsertIntoHiveTableSuite.scala | 52 ++++++++-------- .../spark/sql/hive/ListTablesSuite.scala | 15 ++--- .../hive/ParquetHiveCompatibilitySuite.scala | 22 ++++--- .../spark/sql/hive/QueryPartitionSuite.scala | 8 +-- .../spark/sql/hive/SerializationSuite.scala | 11 ++-- .../spark/sql/hive/StatisticsSuite.scala | 17 ++--- .../org/apache/spark/sql/hive/UDFSuite.scala | 9 ++- .../execution/BigDataBenchmarkSuite.scala | 12 ++-- .../hive/execution/ConcurrentHiveSuite.scala | 17 +++-- .../hive/execution/HiveComparisonTest.scala | 27 ++++---- .../sql/hive/execution/HiveExplainSuite.scala | 7 ++- .../HiveOperatorQueryableSuite.scala | 7 ++- .../sql/hive/execution/HivePlanTest.scala | 9 +-- .../sql/hive/execution/HiveQuerySuite.scala | 35 +++++------ .../hive/execution/HiveResolutionSuite.scala | 4 +- .../sql/hive/execution/HiveSerDeSuite.scala | 15 +++-- .../hive/execution/HiveTableScanSuite.scala | 24 ++++--- .../execution/HiveTypeCoercionSuite.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 62 +++++++++---------- .../sql/hive/execution/PruningSuite.scala | 8 +-- .../execution/ScriptTransformationSuite.scala | 15 ++--- .../hive/orc/OrcPartitionDiscoverySuite.scala | 23 +++---- .../spark/sql/hive/orc/OrcSourceSuite.scala | 13 ++-- .../apache/spark/sql/hive/parquetSuites.scala | 16 ++--- 39 files changed, 295 insertions(+), 303 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/test/{TestHive.scala => TestHiveContext.scala} (96%) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ebbcd9a48243..15452470e5d5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -333,11 +333,12 @@ object SQL { object Hive { + // TODO: check me, will this work? lazy val settings = Seq( javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. + // TODO: re-enable this now that we've gotten rid of the TestHive singleton? parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. 
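With the `TestHive` singleton removed, each suite reaches Hive through the `hiveContext` handle provided by the `HiveTestUtils` trait introduced in the previous commit. A minimal sketch of a suite written against that API follows; the suite name, table name, and query are illustrative, while the trait, context, and helper methods are the ones appearing elsewhere in this patch series.

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.HiveTestUtils

// Illustrative suite; assumes only the HiveTestUtils/hiveContext API from this patch series.
class ExampleNoSingletonSuite extends QueryTest with HiveTestUtils {
  private val ctx = hiveContext
  import ctx.implicits._

  test("query a temp table through the per-suite context") {
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.registerTempTable("example")
    checkAnswer(ctx.sql("SELECT key FROM example WHERE value = 'b'"), Row(2))
    ctx.dropTempTable("example")
  }
}

The compatibility suites in the hunks below follow the same pattern, binding `ctx` once and using it for `setConf`, `cacheTables`, and query execution instead of the removed singleton.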
@@ -356,8 +357,6 @@ object Hive { |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ - |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.hive.test.TestHive.implicits._ |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce diff --git a/sql/README.md b/sql/README.md index 63d4dac9829e..4b8074d85585 100644 --- a/sql/README.md +++ b/sql/README.md @@ -60,7 +60,6 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 1a5ba20404c4..5fefce41a2ba 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive.execution import java.io.File import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution with hash joins. @@ -28,11 +27,11 @@ import org.apache.spark.sql.hive.test.TestHive class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) + ctx.setConf(SQLConf.SORTMERGE_JOIN, false) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + ctx.setConf(SQLConf.SORTMERGE_JOIN, true) super.afterAll() } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c46a4a4b0be5..a54ed6e52916 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,41 +23,41 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { + // TODO: bundle in jar files... 
get from classpath - private lazy val hiveQueryDir = TestHive.getHiveFile( + private lazy val hiveQueryDir = ctx.getHiveFile( "ql/src/test/queries/clientpositive".split("/").mkString(File.separator)) private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private val originalColumnBatchSize = TestHive.conf.columnBatchSize - private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning + private val originalColumnBatchSize = ctx.conf.columnBatchSize + private val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { - TestHive.cacheTables = true + ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) } override def afterAll() { - TestHive.cacheTables = false + ctx.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af..234ec481e79c 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -22,8 +22,6 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.util.Utils /** @@ -33,12 +31,14 @@ import org.apache.spark.util.Utils * files, every `createQueryTest` calls should explicitly set `reset` to `false`. 
*/ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { + import ctx._ + private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -59,7 +59,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte | p_retailprice DOUBLE, | p_comment STRING) """.stripMargin) - val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + val testData1 = ctx.getHiveFile("data/files/part_tiny.txt").getCanonicalPath sql( s""" |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part @@ -83,7 +83,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |row format delimited |fields terminated by '|' """.stripMargin) - val testData2 = TestHive.getHiveFile("data/files/over1k").getCanonicalPath + val testData2 = ctx.getHiveFile("data/files/over1k").getCanonicalPath sql( s""" |LOAD DATA LOCAL INPATH '$testData2' overwrite into table over1k @@ -100,10 +100,10 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte } override def afterAll() { - TestHive.cacheTables = false + ctx.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.reset() + ctx.reset() } ///////////////////////////////////////////////////////////////////////////// @@ -766,7 +766,7 @@ class HiveWindowFunctionQueryFileSuite private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -783,10 +783,10 @@ class HiveWindowFunctionQueryFileSuite } override def afterAll() { - TestHive.cacheTables = false + ctx.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.reset() + ctx.reset() } override def blackList: Seq[String] = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala similarity index 96% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index c659f4e8eb07..33e051363d20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -42,30 +42,12 @@ import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ import scala.collection.JavaConversions._ -// TODO: remove it -object TestHive - extends TestHiveContext( - new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) - /** * A locally running test instance of Spark's Hive execution engine. 
* * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. * Calling [[reset]] will delete all tables and other state in the database, leaving the database * in a "clean" state. - * - * TestHive is singleton object version of this class because instantiating multiple copies of the - * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of - * test cases that rely on TestHive must be serialized. */ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 741a3cd31c60..395cfeccff6a 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.hive.HiveContext; -import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.hive.test.TestHiveContext; public class JavaDataFrameSuite { private transient JavaSparkContext sc; @@ -47,7 +47,7 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - hc = TestHive$.MODULE$; + hc = new TestHiveContext(); sc = new JavaSparkContext(hc.sparkContext()); List jsonObjects = new ArrayList(10); diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 15c2c3deb0d8..30a35a6c06f1 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -38,7 +38,7 @@ import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; import org.apache.spark.sql.hive.HiveContext; -import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.hive.test.TestHiveContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -63,7 +63,7 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = TestHive$.MODULE$; + sqlContext = new TestHiveContext(); sc = new JavaSparkContext(sqlContext.sparkContext()); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 39d315aaeab5..0ebf7f1bd855 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.{SaveMode, AnalysisException, QueryTest} +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.test.MyTestHiveContext import 
org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest { +class CachedTableSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx._ def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -95,18 +96,18 @@ class CachedTableSuite extends QueryTest { test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestHive.uncacheTable("src") + uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.sql("CACHE TABLE src") + sql("CACHE TABLE src") assertCached(table("src")) - assert(TestHive.isCached("src"), "Table 'src' should be cached") + assert(isCached("src"), "Table 'src' should be cached") - TestHive.sql("UNCACHE TABLE src") + sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + assert(!isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 30f5313d2b81..148ab31e91b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,12 +22,14 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.MyTestHiveContext import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with BeforeAndAfter with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index fb10f8583da9..352b883a5e75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.hive.test.MyTestHiveContext // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
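// The MyTestHiveContext trait mixed into these suites is added elsewhere in this patch
// series; its definition is not visible in these hunks. Judging from its usage
// (`private val ctx = hiveContext`), a minimal sketch of the shape it presumably has,
// assuming a ScalaTest BeforeAndAfterAll-style lifecycle; the real trait may differ:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.scalatest.{BeforeAndAfterAll, Suite}

trait ExampleTestHiveContext extends BeforeAndAfterAll { self: Suite =>
  // One SparkContext and one TestHiveContext per suite, instead of one global singleton.
  private lazy val sc =
    new SparkContext("local[2]", "ExampleTestHiveContext", new SparkConf())
  protected lazy val hiveContext: TestHiveContext = new TestHiveContext(sc)

  override protected def afterAll(): Unit = {
    try sc.stop()
    finally super.afterAll()
  }
}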
-class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { +class HiveDataFrameAnalyticsSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ + private var testData: DataFrame = _ override def beforeAll() { testData = Seq((1, 2), (2, 4)).toDF("a", "b") - TestHive.registerDataFrameAsTable(testData, "mytable") + registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - TestHive.dropTempTable("mytable") + dropTempTable("mytable") } test("rollup") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 52e782768cb7..639bb66841b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.MyTestHiveContext -class HiveDataFrameJoinSuite extends QueryTest { +class HiveDataFrameJoinSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c177cbdd991c..49fb287075f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.MyTestHiveContext -class HiveDataFrameWindowSuite extends QueryTest { +class HiveDataFrameWindowSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 983c013bcf86..9c3ae9f64076 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.sql.hive import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.MyTestHiveContext import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with MyTestHiveContext with Logging { + private val ctx = hiveContext test("struct field should accept underscore in sub-column name") { val metastr = "struct" - val datatype = 
HiveMetastoreTypes.toDataType(metastr) assert(datatype.isInstanceOf[StructType]) } @@ -39,8 +39,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } test("duplicated metastore relations") { - import TestHive.implicits._ - val df = TestHive.sql("SELECT * FROM src") + import ctx.implicits._ + val df = ctx.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b8d41065d3f0..ef3d32c3684d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -22,14 +22,14 @@ import java.io.File import scala.collection.mutable.ArrayBuffer import scala.sys.process.{ProcessLogger, Process} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.hive.test.{TestHiveContext, MyTestHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -38,20 +38,20 @@ class HiveSparkSubmitSuite extends SparkFunSuite with Matchers with ResetSystemProperties - with Timeouts { + with Timeouts + with MyTestHiveContext { - // TODO: rewrite these or mark them as slow tests to be run sparingly + private val ctx = hiveContext + import ctx._ - def beforeAll() { - System.setProperty("spark.testing", "true") - } + // TODO: rewrite these or mark them as slow tests to be run sparingly test("SPARK-8368: includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d33e81227db8..bd931c264840 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,62 +22,60 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.hive.test.MyTestHiveContext import org.apache.spark.sql.types._ import 
org.apache.spark.util.Utils -/* Implicits */ -import org.apache.spark.sql.hive.test.TestHive._ - case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ - val testData = TestHive.sparkContext.parallelize( + private val _testData = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - TestHive.reset() + reset() // Register the testData, which will be used in every test. - testData.registerTempTable("testData") + _testData.registerTempTable("testData") } test("insertInto() HiveTable") { sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + _testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + _testData.collect().toSeq ) // Add more data. - testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + _testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq + _testData.toDF().collect().toSeq ++ _testData.toDF().collect().toSeq ) // Now overwrite. - testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") + _testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. 
checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + _testData.collect().toSeq ) } @@ -96,9 +94,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -169,8 +167,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val rowRDD = sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -185,9 +183,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -202,9 +200,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -217,11 +215,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { } test("SPARK-5498:partition schema does not match table schema") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = TestHive.sparkContext.parallelize( + val testDatawithNull = sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6..47281170c691 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.MyTestHiveContext -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.Row - -class ListTablesSuite extends QueryTest with BeforeAndAfterAll { - - import org.apache.spark.sql.hive.test.TestHive.implicits._ +class ListTablesSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index f00d3754c364..c967744bcdf1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{Row, SQLConf} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - override val sqlContext: SQLContext = TestHive + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext /** * Set the staging directory (and hence path to ignore Parquet files under) @@ -40,7 +42,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { withTempTable("data") { - sqlContext.sql( + ctx.sql( s"""CREATE TABLE parquet_compat( | bool_column BOOLEAN, | byte_column TINYINT, @@ -57,16 +59,16 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { |LOCATION '${parquetStore.getCanonicalPath}' """.stripMargin) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + val schema = ctx.table("parquet_compat").schema + val rowRDD = ctx.sparkContext.parallelize(makeRows).coalesce(1) + ctx.createDataFrame(rowRDD, schema).registerTempTable("data") + ctx.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } } override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") + ctx.sql("DROP TABLE parquet_compat") } test("Read Parquet file generated by parquet-hive") { @@ -78,7 +80,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. 
withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + checkAnswer(ctx.read.parquet(parquetStore.getCanonicalPath), makeRows) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 017bc2adc103..2a6c9800afcb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive +class QueryPartitionSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext import ctx.implicits._ import ctx.sql diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 93dcb10f7a29..792f8c595779 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.hive.test.MyTestHiveContext -class SerializationSuite extends SparkFunSuite { +class SerializationSuite extends SparkFunSuite with MyTestHiveContext { + private val ctx = hiveContext test("[SPARK-5840] HiveContext should be serializable") { - val hiveContext = org.apache.spark.sql.hive.test.TestHive - hiveContext.hiveconf + ctx.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() - val bytes = serializer.serialize(hiveContext) - val deSer = serializer.deserialize[AnyRef](bytes) + val bytes = serializer.serialize(ctx) + serializer.deserialize[AnyRef](bytes) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e4fec7e2c8a2..cbe8b27ad10c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,25 +17,20 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll - import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.test.MyTestHiveContext -class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - private lazy val ctx: HiveContext = { - val ctx = org.apache.spark.sql.hive.test.TestHive - ctx.reset() - ctx.cacheTables = false - ctx - } - +class StatisticsSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext import ctx.sql + ctx.reset() + ctx.cacheTables = false + test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { val parsed = HiveQl.parseSql(analyzeCommand) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d..454b973e81d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.MyTestHiveContext case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ +class UDFSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index a3f5921a0cb2..115a9308f98a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -19,17 +19,19 @@ package org.apache.spark.sql.hive.execution import java.io.File -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.MyTestHiveContext /** * A set of test cases based on the big-data-benchmark. * https://amplab.cs.berkeley.edu/benchmark/ */ -class BigDataBenchmarkSuite extends HiveComparisonTest { - val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") +class BigDataBenchmarkSuite extends HiveComparisonTest with MyTestHiveContext { + import ctx._ - val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath - val testTables = Seq( + private val testDataDirectory = + new File("target" + File.separator + "big-data-benchmark-testdata") + private val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath + private val testTables = Seq( TestTable( "rankings", s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index b0d3dd44daed..d54dd7de5751 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,19 +17,24 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => - val ts = - new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) - ts.executeSql("SHOW TABLES").toRdd.collect() - ts.executeSql("SELECT * FROM src").toRdd.collect() - ts.executeSql("SHOW TABLES").toRdd.collect() + val sc = new SparkContext("local", s"TestSQLContext$i", new SparkConf()) + try { + val ts = new TestHiveContext(sc) + ts.executeSql("SHOW TABLES").toRdd.collect() + ts.executeSql("SELECT * FROM src").toRdd.collect() + ts.executeSql("SHOW TABLES").toRdd.collect() + } finally { + sc.stop() + } } } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 638b9c810372..bbb6c1668f62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} +import org.scalatest.GivenWhenThen import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.MyTestHiveContext /** * Allows the creations of tests that execute the same query against both hive @@ -40,7 +40,12 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite + with GivenWhenThen + with MyTestHiveContext + with Logging { + + protected val ctx = hiveContext /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -129,7 +134,7 @@ abstract class HiveComparisonTest } protected def prepareAnswer( - hiveQuery: TestHive.type#QueryExecution, + hiveQuery: ctx.type#QueryExecution, answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { @@ -269,7 +274,7 @@ abstract class HiveComparisonTest try { if (reset) { - TestHive.reset() + ctx.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { @@ -298,7 +303,7 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new TestHive.QueryExecution(_)) + val hiveQueries = queryList.map(new ctx.QueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. // Note this must only look at the logical plan as we might not be able to analyze if // other DDL has not been executed yet. @@ -318,7 +323,7 @@ abstract class HiveComparisonTest case _: ExplainCommand => // No need to execute EXPLAIN queries as we don't check the output. Nil - case _ => TestHive.runSqlHive(queryString) + case _ => ctx.runSqlHive(queryString) } // We need to add a new line to non-empty answers so we can differentiate Seq() @@ -341,14 +346,14 @@ abstract class HiveComparisonTest fail(errorMessage) } }.toSeq - if (reset) { TestHive.reset() } + if (reset) { ctx.reset() } computedResults } // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) + val query = new ctx.QueryExecution(queryString) try { (query, prepareAnswer(query, query.stringResult())) } catch { case e: Throwable => val errorMessage = @@ -408,8 +413,8 @@ abstract class HiveComparisonTest // okay by running a simple query. If this fails then we halt testing since // something must have gone seriously wrong. 
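// The ConcurrentHiveSuite change earlier in this patch wraps each short-lived
// TestHiveContext's SparkContext in try/finally so it is always stopped. The same idea,
// expressed as a small reusable loan helper (illustrative only, not part of the patch):
import org.apache.spark.{SparkConf, SparkContext}

object LocalSparkContextLoan {
  // Creates a local SparkContext, lends it to `body`, and guarantees stop() afterwards.
  def withLocalSparkContext[T](appName: String)(body: SparkContext => T): T = {
    val sc = new SparkContext("local", appName, new SparkConf())
    try body(sc) finally sc.stop()
  }
}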
try { - new TestHive.QueryExecution("SELECT key FROM src").stringResult() - TestHive.runSqlHive("SELECT key FROM src") + new ctx.QueryExecution("SELECT key FROM src").stringResult() + ctx.runSqlHive("SELECT key FROM src") } catch { case e: Exception => logError(s"FATAL ERROR: Canary query threw $e This implies that the " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90..6d7078283383 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.MyTestHiveContext /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx._ + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index efbef68cd444..ab0ac4b5b094 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.MyTestHiveContext /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest { +class HiveOperatorQueryableSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx._ + test("SPARK-5324 query result of describe command") { loadTestTable("src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index ba56a8a6b689..84bb1d48b0a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,11 +21,12 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.MyTestHiveContext -class HivePlanTest extends QueryTest { - import TestHive._ - import TestHive.implicits._ +class HivePlanTest extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a7cfac51cc09..db70491f8ee6 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,9 +30,7 @@ import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.MyTestHiveContext case class TestData(a: Int, b: String) @@ -40,14 +38,15 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with MyTestHiveContext { + import ctx.implicits._ + import ctx._ + private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - import org.apache.spark.sql.hive.test.TestHive.implicits._ - override def beforeAll() { - TestHive.cacheTables = true + cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -55,7 +54,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } override def afterAll() { - TestHive.cacheTables = false + cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") @@ -623,7 +622,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("case sensitivity: registered table") { val testData = - TestHive.sparkContext.parallelize( + sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) testData.toDF().registerTempTable("REGisteredTABle") @@ -645,20 +644,20 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val df = sql("explain select key, count(value) from src group by key") assert(isExplanation(df)) - TestHive.reset() + reset() } test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") + sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) assert(results === Array(Pair("foo", 4))) - TestHive.reset() + reset() } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { @@ -708,7 +707,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) - TestHive.reset() + reset() } test("Exactly once semantics for DDL and command statements") { @@ -796,7 +795,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Describe a registered temporary table. 
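// The beforeAll/afterAll pair earlier in this suite pins the default TimeZone and Locale
// and restores them afterwards. The same save/restore idea as a self-contained helper
// (illustrative only, not part of the patch):
import java.util.{Locale, TimeZone}

object DefaultLocaleFixture {
  def withTimeZoneAndLocale[T](tz: TimeZone, locale: Locale)(body: => T): T = {
    val previousTz = TimeZone.getDefault
    val previousLocale = Locale.getDefault
    TimeZone.setDefault(tz)
    Locale.setDefault(locale)
    try body finally {
      TimeZone.setDefault(previousTz)
      Locale.setDefault(previousLocale)
    }
  }
}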
val testData = - TestHive.sparkContext.parallelize( + sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) testData.toDF().registerTempTable("test_describe_commands2") @@ -823,7 +822,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("ADD JAR command") { - val testJar = TestHive.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + val testJar = getHiveFile("data/files/TestSerDe.jar").getCanonicalPath sql("CREATE TABLE alter1(a INT, b INT)") intercept[Exception] { sql( @@ -836,8 +835,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("ADD JAR command 2") { // this is a test case from mapjoin_addjar.q - val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath - val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + val testJar = getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = getHiveFile("data/files/sample.json").getCanonicalPath sql(s"ADD JAR $testJar") sql( """CREATE TABLE t1(a string, b string) @@ -848,7 +847,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("ADD FILE command") { - val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile + val testFile = getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") val checkAddFileRDD = sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b08db6de2d2f..915256c5d3cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} -import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) @@ -29,6 +27,8 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { + import ctx.implicits._ + import ctx._ test("SPARK-3698: case insensitive test for nested data") { read.json(sparkContext.makeRDD( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 5586a793618b..4c7d464d9bc0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,19 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.MyTestHiveContext /** * A set of tests that validates support for Hive SerDe. 
*/ -class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { +class HiveSerDeSuite extends HiveComparisonTest with MyTestHiveContext { + import org.apache.hadoop.hive.serde2.RegexSerDe + import ctx._ + override def beforeAll(): Unit = { - import TestHive._ - import org.apache.hadoop.hive.serde2.RegexSerDe - super.beforeAll() - TestHive.cacheTables = false + super.beforeAll() + ctx.cacheTables = false sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 2209fc2f30a3..dc2f56cb3369 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ - import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { + import ctx.implicits._ + import ctx._ createQueryTest("partition_based_table_scan_with_different_serde", """ @@ -56,15 +54,15 @@ class HiveTableScanSuite extends HiveComparisonTest { """.stripMargin) test("Spark-4041: lowercase issue") { - TestHive.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") - TestHive.sql("insert into table tb select key, value from src") - TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() - TestHive.sql("drop table tb") + sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") + sql("insert into table tb select key, value from src") + sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() + sql("drop table tb") } test("Spark-4077: timestamp query for null value") { - TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") - TestHive.sql( + sql("DROP TABLE IF EXISTS timestamp_query_null") + sql( """ CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) ROW FORMAT DELIMITED @@ -74,10 +72,10 @@ class HiveTableScanSuite extends HiveComparisonTest { val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") + assert(sql("SELECT time from timestamp_query_null limit 2").collect() === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) - TestHive.sql("DROP TABLE timestamp_query_null") + sql("DROP TABLE timestamp_query_null") } test("Spark-4959 Attributes are case sensitive when using a select query from a projection") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 197e9bfb02c4..4ce21258d875 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} import org.apache.spark.sql.execution.Project -import org.apache.spark.sql.hive.test.TestHive /** * A set of tests that validate type promotion and coercion rules. @@ -43,7 +42,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = ctx.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 7069afc9f7da..571c55d8d23c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,6 +21,8 @@ import java.io.{DataInput, DataOutput} import java.util import java.util.Properties +import scala.collection.JavaConversions._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject @@ -28,13 +30,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.hive.test.MyTestHiveContext import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. @@ -46,10 +46,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. 
*/ -class HiveUDFSuite extends QueryTest { - - import TestHive.{udf, sql} - import TestHive.implicits._ +class HiveUDFSuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -126,12 +126,12 @@ class HiveUDFSuite extends QueryTest { | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } - val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) - TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) + val codegenDefault = getConf(SQLConf.CODEGEN_ENABLED) + setConf(SQLConf.CODEGEN_ENABLED, true) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) + setConf(SQLConf.CODEGEN_ENABLED, false) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { @@ -140,7 +140,7 @@ class HiveUDFSuite extends QueryTest { sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() + reset() } test("SPARK-2693 udaf aggregates test") { @@ -160,7 +160,7 @@ class HiveUDFSuite extends QueryTest { } test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") @@ -171,11 +171,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - TestHive.reset() + reset() } test("UDFToListString") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") @@ -186,11 +186,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - TestHive.reset() + reset() } test("UDFToListInt") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") @@ -201,11 +201,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - TestHive.reset() + reset() } test("UDFToStringIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + @@ -217,11 +217,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - TestHive.reset() + reset() } test("UDFToIntIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() 
testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + @@ -233,11 +233,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - TestHive.reset() + reset() } test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() @@ -249,11 +249,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - TestHive.reset() + reset() } test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") @@ -264,11 +264,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - TestHive.reset() + reset() } test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") @@ -278,11 +278,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - TestHive.reset() + reset() } test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( + val testData = sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: @@ -295,7 +295,7 @@ class HiveUDFSuite extends QueryTest { Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - TestHive.reset() + reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 3bf8f3ac2048..563251ee4ba0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.execution import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.hive.test.TestHive - /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -28,12 +26,12 @@ import scala.collection.JavaConversions._ * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - TestHive.cacheTables = false + ctx.cacheTables = false // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. 
- TestHive.reset() + ctx.reset() // Column pruning tests @@ -145,7 +143,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val plan = new TestHive.QueryExecution(sql).executedPlan + val plan = new ctx.QueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3..ab1d5725dba1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -22,16 +22,17 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest { - override def sqlContext: SQLContext = TestHive + // Use a hive context instead + switchSQLContext(() => new TestHiveContext) + private val ctx = sqlContext.asInstanceOf[TestHiveContext] private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -58,7 +59,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = noSerdeIOSchema - )(TestHive), + )(ctx), rowsDf.collect()) } @@ -72,7 +73,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = serdeIOSchema - )(TestHive), + )(ctx), rowsDf.collect()) } @@ -87,7 +88,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = noSerdeIOSchema - )(TestHive), + )(ctx), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) @@ -104,7 +105,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = serdeIOSchema - )(TestHive), + )(ctx), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index a46ca9a2c970..0b0f821a5371 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,19 +18,16 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.util.Utils + // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -38,7 +35,11 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with MyTestHiveContext { + private val ctx = hiveContext + import ctx.implicits._ + import ctx._ + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { @@ -58,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally TestHive.dropTempTable(tableName) + try f finally ctx.dropTempTable(tableName) } protected def makePartitionDir( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf4645..3b809ba66201 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.MyTestHiveContext case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { +abstract class OrcSuite extends QueryTest with MyTestHiveContext { + protected val ctx = hiveContext + import ctx.implicits._ + import ctx._ + var orcTableDir: File = null var orcTableAsDir: File = null @@ -41,7 +43,6 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir = File.createTempFile("orctests", "sparksql") orcTableDir.delete() orcTableDir.mkdir() - import org.apache.spark.sql.hive.test.TestHive.implicits._ sparkContext .makeRDD(1 to 10) @@ -124,6 +125,8 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { } class OrcSourceSuite extends OrcSuite { + import ctx._ + override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 6084f594a2d6..fb40cd8e7be8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import 
org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.{HiveTestUtils, TestHiveContext} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.test.{SQLTestUtils, MyTestSQLContext} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -54,7 +53,7 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { - private val ctx = sqlContext.asInstanceOf[TestHiveContext] + private val ctx = hiveContext import ctx._ override def beforeAll(): Unit = { @@ -535,7 +534,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { - private val ctx = sqlContext.asInstanceOf[TestHiveContext] + private val ctx = hiveContext import ctx.implicits._ import ctx._ @@ -632,7 +631,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """.stripMargin) checkAnswer( - sqlContext.read.parquet(path), + ctx.read.parquet(path), Row("1st", "2nd", Seq(Row("val_a", "val_b")))) } } @@ -687,11 +686,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. */ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with MyTestSQLContext { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { + private val ctx = hiveContext import ctx.implicits._ import ctx._ From d4aafb16c4201fc31cc0875906f751eee50a9dc1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 11:15:03 -0700 Subject: [PATCH 04/39] Avoid the need to switch to HiveContexts This is a clean up to refactor helper test traits and abstract classes in such a way that is accessible to hive tests. 
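For reviewers, a rough sketch of the pattern this refactoring moves toward (the `*Helpers` names and the `withTempTable` body below are illustrative stand-ins for the `AbstractSQLTestUtils` / `SQLTestUtils` / `HiveTestUtils` traits in the diff, not the actual implementation): shared test helpers are written against an abstract `_sqlContext`, and each family of suites binds it to its own context instead of stopping the SparkContext and switching contexts mid-suite.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.hive.test.TestHiveContext

    // Shared helpers only assume *some* SQLContext is available.
    trait AbstractSQLTestHelpers {
      protected def _sqlContext: SQLContext

      // Example helper: run a block against a temp table, then drop it.
      protected def withTempTable(name: String)(f: => Unit): Unit = {
        try f finally _sqlContext.dropTempTable(name)
      }
    }

    // SQL core suites bind the abstract context to the plain test SQLContext.
    trait SQLTestHelpers extends AbstractSQLTestHelpers {
      protected def sqlContext: SQLContext
      protected final override def _sqlContext: SQLContext = sqlContext
    }

    // Hive suites bind the same helpers to a TestHiveContext, so they no
    // longer need to call switchSQLContext(() => new TestHiveContext).
    trait HiveTestHelpers extends AbstractSQLTestHelpers {
      protected def hiveContext: TestHiveContext
      protected final override def _sqlContext: SQLContext = hiveContext
    }

A hive suite then simply mixes in the hive-side trait (as ScriptTransformationSuite does with HiveSparkPlanTest below) and picks up the shared helpers without any context switching.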
--- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 2 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/execution/SparkPlanTest.scala | 31 +++++++++++------- .../sql/execution/debug/DebuggingSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 2 +- .../parquet/ParquetCompatibilityTest.scala | 19 ++++++++++- .../spark/sql/parquet/ParquetTest.scala | 20 ++++++++---- .../spark/sql/test/MyTestSQLContext.scala | 25 +++++---------- .../apache/spark/sql/test/SQLTestUtils.scala | 6 ++++ .../test/HiveParquetCompatibilityTest.scala | 28 ++++++++++++++++ .../spark/sql/hive/test/HiveParquetTest.scala | 26 +++++++++++++++ .../sql/hive/test/HiveSparkPlanTest.scala | 32 +++++++++++++++++++ .../spark/sql/hive/test/HiveTestUtils.scala | 2 +- .../spark/sql/hive/HiveParquetSuite.scala | 10 ++---- .../hive/ParquetHiveCompatibilitySuite.scala | 8 ++--- .../execution/ScriptTransformationSuite.scala | 11 +++---- 26 files changed, 175 insertions(+), 71 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fb012e2ec3b3..292e88a3187e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.MyTestSQLContext private case class BigData(s: String) class CachedTableSuite extends QueryTest with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 17eacd18dd8d..695d7c284b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 1b7d4f60c9f1..8bbfc92dac72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest with MyTestSQLContext { - 
private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cf8058b17fd0..60ef6ac059fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ * Test suite for functions in [[org.apache.spark.sql.functions]]. */ class DataFrameFunctionsSuite extends QueryTest with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 2a628d2b1345..698211dfdab2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.MyTestSQLContext class DataFrameJoinSuite extends QueryTest with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 954e66c29fc4..7d744e11fa25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a15a2a5a6e0d..7d51e68cd1a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.test.MyTestSQLContext class JoinSuite extends QueryTest with BeforeAndAfterEach with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ed098e258a28..a2296f881c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ class MyDialect extends DefaultParserDialect class SQLQuerySuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index de395d4e1670..2da7874d9d45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,7 +22,7 @@ import 
org.apache.spark.sql.test.SQLTestUtils private case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 6ee417d69fe4..6bf25110e6a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 7c3754f84595..e65d5e3d22c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.MyTestSQLContext class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5f9828c1c068..5be2d1e9e1a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.{execution, Row, SQLConf} class PlannerSuite extends SparkFunSuite with SQLTestUtils { - private val ctx = sqlContextWithData + private val ctx = sqlContext import ctx.implicits._ import ctx.planner._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 4419f09e6cd6..1b4994057239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -27,18 +27,25 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.MyTestSQLContext - /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -abstract class SparkPlanTest extends SparkFunSuite with MyTestSQLContext { +private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with MyTestSQLContext { + protected override def _sqlContext: SQLContext = sqlContext +} + +/** + * Helper class for testing individual physical operators with a pluggable [[SQLContext]]. + */ +private[sql] abstract class AbstractSparkPlanTest extends SparkFunSuite { + protected def _sqlContext: SQLContext /** * Creates a DataFrame from a local Seq of Product. 
*/ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - sqlContext.implicits.localSeqToDataFrameHolder(data) + _sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -99,7 +106,7 @@ abstract class SparkPlanTest extends SparkFunSuite with MyTestSQLContext { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -123,7 +130,7 @@ abstract class SparkPlanTest extends SparkFunSuite with MyTestSQLContext { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -150,13 +157,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -171,7 +178,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -211,12 +218,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -279,10 +286,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = _sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 12ef64252a7b..0bb7f6e0b8e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.MyTestSQLContext class DebuggingSuite extends SparkFunSuite with MyTestSQLContext { - private val ctx = sqlContextWithData + private val ctx = sqlContext test("DataFrame.debug()") { ctx.testData.debug() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 7405cef320df..037dcc738b94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { - private val _ctx = sqlContextWithData + private val _ctx = sqlContext import _ctx.implicits._ import _ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala index 1c8a1c4def24..ade090989ab1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -16,10 +16,12 @@ */ package org.apache.spark.sql.parquet + import java.io.File import scala.collection.JavaConversions._ +import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType @@ -27,7 +29,22 @@ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest + extends AbstractParquetCompatibilityTest + with ParquetTest + +/** + * Abstract helper class for testing Parquet compatibility with a pluggable + * [[org.apache.spark.sql.SQLContext]]. 
+ */ +private[sql] abstract class AbstractParquetCompatibilityTest + extends QueryTest + with AbstractParquetTest + with BeforeAndAfterAll { + protected var parquetStore: File = _ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index e30e2f503e54..0eb3358efa24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.test.{AbstractSQLTestUtils, SQLTestUtils} +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +33,13 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils { +private[sql] trait ParquetTest extends AbstractParquetTest with SQLTestUtils + +/** + * Abstract helper trait for Parquet tests with a pluggable [[SQLContext]]. + */ +private[sql] trait AbstractParquetTest extends SparkFunSuite with AbstractSQLTestUtils { + protected def _sqlContext: SQLContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -43,7 +49,7 @@ private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -55,7 +61,7 @@ private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) } /** @@ -67,14 +73,14 @@ private[sql] trait ParquetTest extends SparkFunSuite with SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + _sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala index 1a2dcba81648..c1a53cac3ec8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala @@ -64,34 +64,25 @@ private[spark] class MyLocalSQLContext(sc: SparkContext) extends SQLContext(sc) private[spark] trait MyTestSQLContext extends 
SparkFunSuite with BeforeAndAfterAll { /** - * The [[SQLContext]] to use for all tests in this suite. + * The [[MyLocalSQLContext]] to use for all tests in this suite. * * By default, the underlying [[SparkContext]] will be run in local mode with the default * test configurations. */ - private var _ctx: SQLContext = new MyLocalSQLContext - - /** The [[SQLContext]] to use for all tests in this suite. */ - protected def sqlContext: SQLContext = _ctx + private var _ctx: MyLocalSQLContext = new MyLocalSQLContext /** * The [[MyLocalSQLContext]] to use for all tests in this suite. - * This one comes with all the data prepared in advance. */ - protected def sqlContextWithData: MyLocalSQLContext = { - _ctx match { - case local: MyLocalSQLContext => local - case _ => fail("this SQLContext does not have data prepared in advance") - } - } + protected def sqlContext: MyLocalSQLContext = _ctx /** - * Switch to the provided [[SQLContext]]. + * Switch to the provided [[MyLocalSQLContext]]. * * This stops the underlying [[SparkContext]] and expects a new one to be created. * This is needed because only one [[SparkContext]] is allowed per JVM. */ - protected def switchSQLContext(newContext: () => SQLContext): Unit = { + protected def switchSQLContext(newContext: () => MyLocalSQLContext): Unit = { if (_ctx != null) { _ctx.sparkContext.stop() _ctx = newContext() @@ -99,10 +90,10 @@ private[spark] trait MyTestSQLContext extends SparkFunSuite with BeforeAndAfterA } /** - * Execute the given block of code with a custom [[SQLContext]]. - * At the end of the method, a [[MyLocalSQLContext]] will be restored. + * Execute the given block of code with a custom [[MyLocalSQLContext]]. + * At the end of the method, the default [[MyLocalSQLContext]] will be restored. */ - protected def withSQLContext[T](newContext: () => SQLContext)(body: => T) { + protected def withSQLContext[T](newContext: () => MyLocalSQLContext)(body: => T) { switchSQLContext(newContext) try { body diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 5d3af085bd60..e1b7f6ada7ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -26,6 +26,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils +/** + * General helper trait for common functionality in SQL tests. + */ private[spark] trait SQLTestUtils extends SparkFunSuite with AbstractSQLTestUtils @@ -34,6 +37,9 @@ private[spark] trait SQLTestUtils protected final override def _sqlContext = sqlContext } +/** + * Abstract helper trait for SQL tests with a pluggable [[SQLContext]]. + */ private[spark] trait AbstractSQLTestUtils { this: SparkFunSuite => protected def _sqlContext: SQLContext diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala new file mode 100644 index 000000000000..afcd0c97d11c --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.sql.parquet.AbstractParquetCompatibilityTest + +/** + * Helper class for testing Parquet compatibility in hive. + * This is analogous to [[org.apache.spark.sql.parquet.ParquetCompatibilityTest]]. + */ +private[hive] abstract class HiveParquetCompatibilityTest + extends AbstractParquetCompatibilityTest + with HiveParquetTest diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala new file mode 100644 index 000000000000..b909a9c88829 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.sql.parquet.AbstractParquetTest + +/** + * Helper trait for Parquet tests analogous to [[org.apache.spark.sql.parquet.ParquetTest]]. + */ +private[hive] trait HiveParquetTest extends AbstractParquetTest with HiveTestUtils + diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala new file mode 100644 index 000000000000..13b8ccab33c4 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.sql.execution.AbstractSparkPlanTest +import org.apache.spark.sql.SQLContext + +/** + * Base class for writing tests for individual physical operators in hive. + * This is analogous to [[org.apache.spark.sql.execution.SparkPlanTest]]. + */ +private[sql] abstract class HiveSparkPlanTest + extends AbstractSparkPlanTest + with MyTestHiveContext { + + protected override def _sqlContext: SQLContext = hiveContext +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala index 266282e228e6..f1635af9c857 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.AbstractSQLTestUtils /** - * This is analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. + * Helper trait analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. */ private[spark] trait HiveTestUtils extends SparkFunSuite diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 59464a048971..87ce5e815e91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,17 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.hive.test.HiveParquetTest import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext +class HiveParquetSuite extends QueryTest with HiveParquetTest { + private val ctx = hiveContext import ctx._ test("Case insensitive attribute names") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index c967744bcdf1..d14db04eb990 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.HiveParquetCompatibilityTest import org.apache.spark.sql.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf} -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetHiveCompatibilitySuite extends HiveParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext + private val ctx = hiveContext /** * Set the staging directory (and hence path to ignore Parquet files under) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index ab1d5725dba1..2a6d9d0d57bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,15 +24,12 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.test.HiveSparkPlanTest import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest { - - // Use a hive context instead - switchSQLContext(() => new TestHiveContext) - private val ctx = sqlContext.asInstanceOf[TestHiveContext] +class ScriptTransformationSuite extends HiveSparkPlanTest { + private val ctx = hiveContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, From 6345cee8f4fca7d5819bfddad51cdf6fe5d4bfe8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 11:46:25 -0700 Subject: [PATCH 05/39] Clean up JsonSuite --- .../org/apache/spark/sql/json/JsonSuite.scala | 96 +++++++++---------- .../apache/spark/sql/json/TestJsonData.scala | 37 ++++--- 2 files changed, 64 insertions(+), 69 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 037dcc738b94..79a25d53b7b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,20 +23,18 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.json.InferSchema.compatibleType -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { - private val _ctx = sqlContext - import _ctx.implicits._ - import _ctx._ - - protected override def ctx: SQLContext = _ctx +class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { + private val ctx = sqlContext + import ctx.implicits._ + import ctx._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -220,7 +218,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = _sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -239,7 +237,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = _sqlContext.read.json(primitiveFieldAndType) val 
expectedSchema = StructType( StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: @@ -267,7 +265,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = _sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -366,7 +364,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = _sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -382,7 +380,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = _sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -456,7 +454,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = _sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -509,7 +507,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = _sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -533,7 +531,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = _sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -561,7 +559,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = _sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -580,9 +578,9 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - ctx.sparkContext.parallelize(1 to 100) + _sqlContext.sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = _sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -597,7 +595,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + _sqlContext.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === 
Some(path)) assert(relationWithSchema.schema === schema) @@ -609,7 +607,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = _sqlContext.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: @@ -678,7 +676,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = _sqlContext.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -695,7 +693,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = _sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -716,7 +714,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = _sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -744,7 +742,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = _sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -770,7 +768,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = _sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -788,7 +786,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = _sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -811,7 +809,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = _sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -829,10 +827,10 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { test("Corrupt records") { // Test if we can query corrupt records. 
- val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = _sqlContext.conf.columnNameOfCorruptRecord + _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = ctx.read.json(corruptRecords) + val jsonDF = _sqlContext.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -882,11 +880,11 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { Row("]") :: Nil ) - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = _sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -932,7 +930,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = _sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -955,7 +953,7 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = _sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -963,8 +961,8 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = _sqlContext.read.json(primitiveFieldAndType) + val primTable = _sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -976,8 +974,8 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = _sqlContext.read.json(complexFieldAndType1) + val compTable = _sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. 
checkAnswer( @@ -1045,19 +1043,19 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { "path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - ctx) + _sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( "path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - ctx) + _sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( "path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)), - ctx) + _sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation1 === relation2) @@ -1080,24 +1078,24 @@ class JsonSuite extends QueryTest with TestJsonData with MyTestSQLContext { } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = _sqlContext.conf.columnNameOfCorruptRecord + _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try { val temp = Utils.createTempDir().getPath - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val df = _sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) + assert(_sqlContext.read.parquet(temp).count() == 5) - val df2 = ctx.read.json(corruptRecords) + val df2 = _sqlContext.read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) + checkAnswer(_sqlContext.read.parquet(temp), df2.collect()) } finally { - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 182063097e32..4a06074a7864 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -// TODO: clean me up - -trait TestJsonData { - - protected def ctx: SQLContext +private[json] trait TestJsonData { + protected def _sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -38,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -49,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: 
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -67,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -82,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -98,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -152,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -160,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -169,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -192,7 +189,7 @@ trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -200,5 +197,5 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) } From 68ac6fe3994cbe3aa4a51909cf9e130e8c46cfb1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 12:05:35 -0700 Subject: [PATCH 
06/39] Rename the test traits properly MyTestSQLContext -> SharedSQLContext MyTestHiveContext -> SharedHiveContext LocalSQLContext -> TestSQLContext TestData -> TestSQLData --- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- .../spark/sql/DataFrameImplicitsSuite.scala | 4 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 4 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 4 +- .../apache/spark/sql/DataFrameStatSuite.scala | 4 +- .../apache/spark/sql/DateFunctionsSuite.scala | 4 +- .../org/apache/spark/sql/JoinSuite.scala | 4 +- .../apache/spark/sql/ListTablesSuite.scala | 4 +- .../spark/sql/MathExpressionsSuite.scala | 4 +- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../org/apache/spark/sql/SQLConfSuite.scala | 4 +- .../apache/spark/sql/SQLContextSuite.scala | 4 +- .../sql/ScalaReflectionRelationSuite.scala | 4 +- .../apache/spark/sql/SerializationSuite.scala | 4 +- .../spark/sql/StringFunctionsSuite.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 4 +- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../columnar/PartitionBatchPruningSuite.scala | 4 +- .../spark/sql/execution/SparkPlanTest.scala | 4 +- .../execution/SparkSqlSerializer2Suite.scala | 6 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 4 +- .../UnsafeKVExternalSorterSuite.scala | 4 +- .../sql/execution/debug/DebuggingSuite.scala | 4 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 4 +- .../spark/sql/sources/DataSourceTest.scala | 4 +- .../spark/sql/test/MyTestSQLContext.scala | 109 ------------------ .../apache/spark/sql/test/SQLTestUtils.scala | 15 ++- .../spark/sql/test/SharedSQLContext.scala | 77 +++++++++++++ .../spark/sql/test/TestSQLContext.scala | 45 ++++++++ .../{MyTestData.scala => TestSQLData.scala} | 86 +++++++------- .../spark/sql/hive/test/TestHiveContext.scala | 2 +- .../sql/hive/test/HiveSparkPlanTest.scala | 2 +- .../spark/sql/hive/test/HiveTestUtils.scala | 2 +- ...eContext.scala => SharedHiveContext.scala} | 29 +++-- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../spark/sql/hive/ErrorPositionSuite.scala | 4 +- .../hive/HiveDataFrameAnalyticsSuite.scala | 4 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 4 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 4 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 4 +- .../spark/sql/hive/ListTablesSuite.scala | 4 +- .../spark/sql/hive/QueryPartitionSuite.scala | 4 +- .../spark/sql/hive/SerializationSuite.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 4 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 +- .../execution/BigDataBenchmarkSuite.scala | 4 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveExplainSuite.scala | 4 +- .../HiveOperatorQueryableSuite.scala | 4 +- .../sql/hive/execution/HivePlanTest.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 4 +- .../sql/hive/execution/HiveSerDeSuite.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 4 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 4 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 4 +- 60 files changed, 302 insertions(+), 273 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala create mode 100644 
sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala rename sql/core/src/test/scala/org/apache/spark/sql/test/{MyTestData.scala => TestSQLData.scala} (70%) rename sql/hive/src/test/java/org/apache/spark/sql/hive/test/{MyTestHiveContext.scala => SharedHiveContext.scala} (72%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 292e88a3187e..7bd5b9c73865 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,11 +25,11 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext private case class BigData(s: String) -class CachedTableSuite extends QueryTest with MyTestSQLContext { +class CachedTableSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8bbfc92dac72..7238b8dc1ee8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest with MyTestSQLContext { +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 60ef6ac059fa..903ca6fa9475 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest with MyTestSQLContext { +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 9b598112732a..6205a73f03dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameImplicitsSuite extends QueryTest with MyTestSQLContext { +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 698211dfdab2..216a24a533dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest with MyTestSQLContext { +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index f0ae14935422..97302bf88d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest with MyTestSQLContext { +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 9096af0251b1..8ea9fa9f88ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -22,9 +22,9 @@ import java.util.Random import org.scalatest.Matchers._ import org.apache.spark.sql.functions.col -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest with MyTestSQLContext { +class DataFrameStatSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 21a3515ecb08..586c3736556b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,10 +22,10 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest with MyTestSQLContext { +class DateFunctionsSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 7d51e68cd1a1..1cba0b769542 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach with MyTestSQLContext { +class JoinSuite extends QueryTest with BeforeAndAfterEach with SQLTestUtils { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 4b366482153f..87e37837c6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter with MyTestSQLContext { +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index bf2525e1568a..b05c0def05d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest with MyTestSQLContext { +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ private val ctx = sqlContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index be0ad1d67fdd..c30f44e807e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite 
import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite with MyTestSQLContext { +class RowSuite extends SparkFunSuite with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index eca293ad20d9..508105a0d056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest with MyTestSQLContext { +class SQLConfSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext private val testKey = "test.key.0" private val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index ee77e7b72e27..038e9b069cab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with MyTestSQLContext { +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { private val ctx = sqlContext override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 9a374f52661d..e117ca4c511c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -72,7 +72,7 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite with MyTestSQLContext { +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 31abd7835268..45d0ee4a8e74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite 
with MyTestSQLContext { +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { val _sqlContext = new SQLContext(sqlContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index b6a1b6c67917..f8a89f323047 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest with MyTestSQLContext { +class StringFunctionsSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f060851a363a..ddf3c184839b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,7 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -67,7 +67,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest with MyTestSQLContext { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 6bf25110e6a5..ab7c5e47cb77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.test.MyTestSQLContext -import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest with MyTestSQLContext { +class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index e65d5e3d22c5..591e1ff4a789 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,9 +21,9 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 1b4994057239..94f56b456147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -25,13 +25,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with MyTestSQLContext { +private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with SharedSQLContext { protected override def _sqlContext: SQLContext = sqlContext } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index f55503ba1d74..8fe68458e824 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -22,11 +22,11 @@ import java.sql.{Timestamp, Date} import org.apache.spark.serializer.Serializer import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with MyTestSQLContext { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with SharedSQLContext { // Make sure that we will not use serializer2 for unsupported data types. 
def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = @@ -65,7 +65,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with MyTestSQLConte checkSupported(new MyDenseVectorUDT, isSupported = false) } -abstract class SparkSqlSerializer2Suite extends QueryTest with MyTestSQLContext { +abstract class SparkSqlSerializer2Suite extends QueryTest with SharedSQLContext { protected val ctx = sqlContext var allColumns: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 9f1101fe6aa2..7eaa16fe8952 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -39,7 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers - with MyTestSQLContext { + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bb4ab9f1e986..72c5ae1fa051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,14 +23,14 @@ import org.apache.spark._ import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. 
*/ -class UnsafeKVExternalSorterSuite extends SparkFunSuite with MyTestSQLContext { +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 0bb7f6e0b8e9..b242ff49542a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class DebuggingSuite extends SparkFunSuite with MyTestSQLContext { +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { private val ctx = sqlContext test("DataFrame.debug()") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 6f1206230767..04c3606d648f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,11 +25,11 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 073171743683..0284a7a5e858 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,11 +24,11 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with MyTestSQLContext { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { private val ctx = sqlContext import ctx.implicits._ import ctx._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index fa253105bdb7..6b7200bddfee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.sources import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.test.MyTestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -abstract class DataSourceTest extends QueryTest with BeforeAndAfter with MyTestSQLContext 
{ +abstract class DataSourceTest extends QueryTest with BeforeAndAfter with SharedSQLContext { // We want to test some edge cases. protected implicit lazy val caseInsensitiveContext = { val ctx = new SQLContext(sqlContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala deleted file mode 100644 index c1a53cac3ec8..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestSQLContext.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** - * A SQLContext that can be used for local testing. - */ -private[spark] class MyLocalSQLContext(sc: SparkContext) extends SQLContext(sc) with MyTestData { - - def this() { - this(new SparkContext("local[2]", "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true"))) - } - - // For test data - protected override val sqlContext: SQLContext = this - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - } - } - - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } -} - -/** - * A scalatest trait for test suites where all tests share a single [[SQLContext]]. - */ -private[spark] trait MyTestSQLContext extends SparkFunSuite with BeforeAndAfterAll { - - /** - * The [[MyLocalSQLContext]] to use for all tests in this suite. - * - * By default, the underlying [[SparkContext]] will be run in local mode with the default - * test configurations. - */ - private var _ctx: MyLocalSQLContext = new MyLocalSQLContext - - /** - * The [[MyLocalSQLContext]] to use for all tests in this suite. - */ - protected def sqlContext: MyLocalSQLContext = _ctx - - /** - * Switch to the provided [[MyLocalSQLContext]]. - * - * This stops the underlying [[SparkContext]] and expects a new one to be created. - * This is needed because only one [[SparkContext]] is allowed per JVM. 
- */ - protected def switchSQLContext(newContext: () => MyLocalSQLContext): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = newContext() - } - } - - /** - * Execute the given block of code with a custom [[MyLocalSQLContext]]. - * At the end of the method, the default [[MyLocalSQLContext]] will be restored. - */ - protected def withSQLContext[T](newContext: () => MyLocalSQLContext)(body: => T) { - switchSQLContext(newContext) - try { - body - } finally { - switchSQLContext(() => new MyLocalSQLContext) - } - } - - protected override def afterAll(): Unit = { - switchSQLContext(() => null) - super.afterAll() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e1b7f6ada7ab..57368bdc24d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -21,9 +21,11 @@ import java.io.File import java.util.UUID import scala.util.Try +import scala.language.implicitConversions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils /** @@ -32,7 +34,7 @@ import org.apache.spark.util.Utils private[spark] trait SQLTestUtils extends SparkFunSuite with AbstractSQLTestUtils - with MyTestSQLContext { + with SharedSQLContext { protected final override def _sqlContext = sqlContext } @@ -128,4 +130,13 @@ private[spark] trait AbstractSQLTestUtils { this: SparkFunSuite => _sqlContext.sql(s"USE $db") try f finally _sqlContext.sql(s"USE default") } + + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(_sqlContext, plan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..0473e4666544 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. 
+ */ +private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = new TestSQLContext + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def sqlContext: TestSQLContext = _ctx + + /** + * Switch to a custom [[TestSQLContext]]. + * + * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one + * to be created. This is necessary because only one [[org.apache.spark.SparkContext]] + * is allowed per JVM. + */ + protected def switchSQLContext(newContext: () => TestSQLContext): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = newContext() + } + } + + /** + * Execute the given block of code with a custom [[TestSQLContext]]. + * At the end of the method, the default [[TestSQLContext]] will be restored. + */ + protected def withSQLContext[T](newContext: () => TestSQLContext)(body: => T) { + switchSQLContext(newContext) + try { + body + } finally { + switchSQLContext(() => new TestSQLContext) + } + } + + protected override def afterAll(): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + super.afterAll() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 000000000000..7f179ea5b749 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} + +/** + * A special [[SQLContext]] prepared for testing. + */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with TestSQLData { + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) + } + + // For SQLTestData + protected override val _sqlContext: SQLContext = this + + // Use fewer partitions to speed up testing + override protected[sql] def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal.
*/ + protected[sql] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala similarity index 70% rename from sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala index 30ffb07ffb4a..e080c40813ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/MyTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala @@ -23,21 +23,21 @@ import org.apache.spark.sql.{DataFrame, SQLContext} /** * A collection of sample data used in SQL tests. */ -private[spark] trait MyTestData { - protected val sqlContext: SQLContext - import sqlContext.implicits._ +private[sql] trait TestSQLData { + protected val _sqlContext: SQLContext + import _sqlContext.implicits._ // All test data should be lazy because the SQLContext is not set up yet lazy val testData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } lazy val testData2: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -49,7 +49,7 @@ private[spark] trait MyTestData { } lazy val testData3: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -57,14 +57,14 @@ private[spark] trait MyTestData { } lazy val negativeData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } lazy val largeAndSmallInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -76,7 +76,7 @@ private[spark] trait MyTestData { } lazy val decimalData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -88,7 +88,7 @@ private[spark] trait MyTestData { } lazy val binaryData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( BinaryData("12".getBytes, 1) :: BinaryData("22".getBytes, 5) :: BinaryData("122".getBytes, 3) :: @@ -99,7 +99,7 @@ private[spark] trait MyTestData { } lazy val upperCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -111,7 +111,7 @@ private[spark] trait MyTestData { } lazy val lowerCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -121,7 +121,7 @@ private[spark] trait MyTestData { } lazy val arrayData: RDD[ArrayData] = { - val rdd = 
sqlContext.sparkContext.parallelize( + val rdd = _sqlContext.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) rdd.toDF().registerTempTable("arrayData") @@ -129,7 +129,7 @@ private[spark] trait MyTestData { } lazy val mapData: RDD[MapData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = _sqlContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -140,13 +140,13 @@ private[spark] trait MyTestData { } lazy val repeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = _sqlContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -154,7 +154,7 @@ private[spark] trait MyTestData { } lazy val nullInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -164,7 +164,7 @@ private[spark] trait MyTestData { } lazy val allNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -174,7 +174,7 @@ private[spark] trait MyTestData { } lazy val nullStrings: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -183,13 +183,13 @@ private[spark] trait MyTestData { } lazy val tableName: DataFrame = { - val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } lazy val unparsedStrings: RDD[String] = { - sqlContext.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -198,13 +198,13 @@ private[spark] trait MyTestData { // An RDD with 4 elements and 8 partitions lazy val withEmptyParts: RDD[IntField] = { - val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } lazy val person: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -212,7 +212,7 @@ private[spark] trait MyTestData { } lazy val salary: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -220,7 +220,7 @@ private[spark] trait MyTestData { } lazy val complexData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = _sqlContext.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), 
Seq(2, 2, 2), false) :: Nil).toDF() @@ -232,22 +232,22 @@ private[spark] trait MyTestData { | Case classes used in test data | * ------------------------------ */ - private[spark] case class TestData(key: Int, value: String) - private[spark] case class TestData2(a: Int, b: Int) - private[spark] case class TestData3(a: Int, b: Option[Int]) - private[spark] case class LargeAndSmallInts(a: Int, b: Int) - private[spark] case class DecimalData(a: BigDecimal, b: BigDecimal) - private[spark] case class BinaryData(a: Array[Byte], b: Int) - private[spark] case class UpperCaseData(N: Int, L: String) - private[spark] case class LowerCaseData(n: Int, l: String) - private[spark] case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - private[spark] case class MapData(data: scala.collection.Map[Int, String]) - private[spark] case class StringData(s: String) - private[spark] case class IntField(i: Int) - private[spark] case class NullInts(a: Integer) - private[spark] case class NullStrings(n: Int, s: String) - private[spark] case class TableName(tableName: String) - private[spark] case class Person(id: Int, name: String, age: Int) - private[spark] case class Salary(personId: Int, salary: Double) - private[spark] case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + private[sql] case class TestData(key: Int, value: String) + private[sql] case class TestData2(a: Int, b: Int) + private[sql] case class TestData3(a: Int, b: Option[Int]) + private[sql] case class LargeAndSmallInts(a: Int, b: Int) + private[sql] case class DecimalData(a: BigDecimal, b: BigDecimal) + private[sql] case class BinaryData(a: Array[Byte], b: Int) + private[sql] case class UpperCaseData(N: Int, L: String) + private[sql] case class LowerCaseData(n: Int, l: String) + private[sql] case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + private[sql] case class MapData(data: scala.collection.Map[Int, String]) + private[sql] case class StringData(s: String) + private[sql] case class IntField(i: Int) + private[sql] case class NullInts(a: Integer) + private[sql] case class NullStrings(n: Int, s: String) + private[sql] case class TableName(tableName: String) + private[sql] case class Person(id: Int, name: String, age: Int) + private[sql] case class Salary(personId: Int, salary: Double) + private[sql] case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 33e051363d20..6fc4518fa78f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -37,7 +38,6 @@ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ import scala.collection.JavaConversions._ diff --git 
a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala index 13b8ccab33c4..2582461f2bae 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SQLContext */ private[sql] abstract class HiveSparkPlanTest extends AbstractSparkPlanTest - with MyTestHiveContext { + with SharedHiveContext { protected override def _sqlContext: SQLContext = hiveContext } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala index f1635af9c857..2357477f1bfd 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.test.AbstractSQLTestUtils private[spark] trait HiveTestUtils extends SparkFunSuite with AbstractSQLTestUtils - with MyTestHiveContext { + with SharedHiveContext { protected final override def _sqlContext = hiveContext } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala similarity index 72% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala rename to sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index ef0fe2e2e27c..efec16cf74da 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/MyTestHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -23,28 +23,30 @@ import org.apache.spark.SparkFunSuite /** - * A scalatest trait for test suites where all tests share a single - * [[org.apache.spark.sql.hive.HiveContext]]. + * Helper trait for Hive test suites where all tests share a single [[TestHiveContext]]. + * This is analogous to [[org.apache.spark.sql.test.SharedSQLContext]]. */ -private[spark] trait MyTestHiveContext extends SparkFunSuite with BeforeAndAfterAll { +private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfterAll { /** * The [[TestHiveContext]] to use for all tests in this suite. * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local mode - * with the default test configurations. + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. */ private var _ctx: TestHiveContext = new TestHiveContext - /** The [[TestHiveContext]] to use for all tests in this suite. */ + /** + * The [[TestHiveContext]] to use for all tests in this suite. + */ protected def hiveContext: TestHiveContext = _ctx /** - * Switch to the provided [[org.apache.spark.sql.hive.HiveContext]]. + * Switch to a custom [[TestHiveContext]]. + * - * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one to - * be created. This is needed because only one [[org.apache.spark.SparkContext]] is - * allowed per JVM. + * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one + * to be created. This is needed because only one [[org.apache.spark.SparkContext]] + * is allowed per JVM.
*/ protected def switchHiveContext(newContext: () => TestHiveContext): Unit = { if (_ctx != null) { @@ -55,7 +57,7 @@ private[spark] trait MyTestHiveContext extends SparkFunSuite with BeforeAndAfter /** * Execute the given block of code with a custom [[TestHiveContext]]. - * At the end of the method, a [[TestHiveContext]] will be restored. + * At the end of the method, the default [[TestHiveContext]] will be restored. */ protected def withHiveContext[T](newContext: () => TestHiveContext)(body: => T) { switchHiveContext(newContext) @@ -67,7 +69,10 @@ private[spark] trait MyTestHiveContext extends SparkFunSuite with BeforeAndAfter } protected override def afterAll(): Unit = { - switchHiveContext(() => null) + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } super.afterAll() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 0ebf7f1bd855..78f3db14ef10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -21,11 +21,11 @@ import java.io.File import org.apache.spark.sql.{SaveMode, AnalysisException, QueryTest} import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest with MyTestHiveContext { +class CachedTableSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 148ab31e91b3..c0aa6c281d0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,11 +22,11 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter with MyTestHiveContext { +class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 352b883a5e75..2f0d54aba112 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with MyTestHiveContext { +class HiveDataFrameAnalyticsSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 639bb66841b1..22b6e1d25804 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class HiveDataFrameJoinSuite extends QueryTest with MyTestHiveContext { +class HiveDataFrameJoinSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 49fb287075f8..14f0c0252013 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class HiveDataFrameWindowSuite extends QueryTest with MyTestHiveContext { +class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 9c3ae9f64076..02c5a4b4c92e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.hive import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite with MyTestHiveContext with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext with Logging { private val ctx = hiveContext test("struct field should accept underscore in sub-column name") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index ef3d32c3684d..0aae289dc94c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHiveContext, MyTestHiveContext} +import 
org.apache.spark.sql.hive.test.{TestHiveContext, SharedHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -39,7 +39,7 @@ class HiveSparkSubmitSuite with Matchers with ResetSystemProperties with Timeouts - with MyTestHiveContext { + with SharedHiveContext { private val ctx = hiveContext import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index bd931c264840..d02f2ac8beb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -32,7 +32,7 @@ case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with MyTestHiveContext { +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 47281170c691..b6aa12a5f6e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class ListTablesSuite extends QueryTest with MyTestHiveContext { +class ListTablesSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 2a6c9800afcb..0f3a5e16088e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest with MyTestHiveContext { +class QueryPartitionSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx.sql diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 792f8c595779..376dc7ebdd4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import 
org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class SerializationSuite extends SparkFunSuite with MyTestHiveContext { +class SerializationSuite extends SparkFunSuite with SharedHiveContext { private val ctx = hiveContext test("[SPARK-5840] HiveContext should be serializable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index cbe8b27ad10c..ba78cd932631 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,9 +22,9 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class StatisticsSuite extends QueryTest with MyTestHiveContext { +class StatisticsSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.sql diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 454b973e81d1..7355c62d2c8b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with MyTestHiveContext { +class UDFSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext test("UDF case insensitive") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 115a9308f98a..6f26e2119c66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive.execution import java.io.File -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of test cases based on the big-data-benchmark. 
* https://amplab.cs.berkeley.edu/benchmark/ */ -class BigDataBenchmarkSuite extends HiveComparisonTest with MyTestHiveContext { +class BigDataBenchmarkSuite extends HiveComparisonTest with SharedHiveContext { import ctx._ private val testDataDirectory = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index bbb6c1668f62..a001bba46f00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext /** * Allows the creations of tests that execute the same query against both hive @@ -42,7 +42,7 @@ import org.apache.spark.sql.hive.test.MyTestHiveContext abstract class HiveComparisonTest extends SparkFunSuite with GivenWhenThen - with MyTestHiveContext + with SharedHiveContext with Logging { protected val ctx = hiveContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 6d7078283383..3cd8ed3125d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest with MyTestHiveContext { +class HiveExplainSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index ab0ac4b5b094..bd7a456fcf50 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest with MyTestHiveContext { +class HiveOperatorQueryableSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index 84bb1d48b0a1..1d0ce8e304ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext -class HivePlanTest extends QueryTest with MyTestHiveContext { +class HivePlanTest extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index db70491f8ee6..3380129b29a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext case class TestData(a: Int, b: String) @@ -38,7 +38,7 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. 
*/ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with MyTestHiveContext { +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedHiveContext { import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 4c7d464d9bc0..c508c3b7d7df 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of tests that validates support for Hive SerDe. */ -class HiveSerDeSuite extends HiveComparisonTest with MyTestHiveContext { +class HiveSerDeSuite extends HiveComparisonTest with SharedHiveContext { import org.apache.hadoop.hive.serde2.RegexSerDe import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 571c55d8d23c..2604ee5d2794 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.util.Utils case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest with MyTestHiveContext { +class HiveUDFSuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 0b0f821a5371..a273467695cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. 
@@ -35,7 +35,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with MyTestHiveContext { +class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { private val ctx = hiveContext import ctx.implicits._ import ctx._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 3b809ba66201..7e25aab1b9ac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.MyTestHiveContext +import org.apache.spark.sql.hive.test.SharedHiveContext case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with MyTestHiveContext { +abstract class OrcSuite extends QueryTest with SharedHiveContext { protected val ctx = hiveContext import ctx.implicits._ import ctx._ From b15fdc6713c125b8f3ea0047c752351bcbc84d0c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 13:04:21 -0700 Subject: [PATCH 07/39] Stop SparkContext in Java SQL tests --- .../test/org/apache/spark/sql/JavaApplySchemaSuite.java | 3 ++- .../test/org/apache/spark/sql/JavaDataFrameSuite.java | 3 ++- .../test/java/test/org/apache/spark/sql/JavaUDFSuite.java | 3 +++ .../org/apache/spark/sql/sources/JavaSaveLoadSuite.java | 8 ++++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 478210d72d02..446e404d1239 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -55,8 +55,9 @@ public void setUp() { @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index f9444f60a08d..c653b0639aa3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -52,8 +52,9 @@ public void setUp() { @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 98c8a4aca6ca..bb02b58cca9b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -47,6 +47,9 @@ public void setUp() { @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 90f802a01ef5..6f9e7f68dc39 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -72,6 +73,13 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { Map options = new HashMap(); From 0d74a72b9547c2eb0bf9ca592e9e56c2e46f0836 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 14:20:07 -0700 Subject: [PATCH 08/39] Load test data early in case tables are accessed by name The test data is currently loaded as a bunch of lazy vals. If the data is accessed by name, however, they won't be loaded automatically. This patch adds an explicit method call that loads the data if necessary. --- project/SparkBuild.scala | 2 -- .../apache/spark/sql/JavaDataFrameSuite.java | 6 ++-- .../org/apache/spark/sql/JoinSuite.scala | 2 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 ++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 ++ .../spark/sql/execution/PlannerSuite.scala | 2 ++ .../{TestSQLData.scala => SQLTestData.scala} | 31 ++++++++++++++++++- .../apache/spark/sql/test/SQLTestUtils.scala | 4 +-- .../spark/sql/test/SharedSQLContext.scala | 5 +++ .../spark/sql/test/TestSQLContext.scala | 4 +-- 10 files changed, 51 insertions(+), 9 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/test/{TestSQLData.scala => SQLTestData.scala} (93%) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 15452470e5d5..0cbb23b332d1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -338,8 +338,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // TODO: re-enable this now that we've gotten rid of the TestHive singleton? - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. 
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c653b0639aa3..9be7af272515 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -34,12 +34,13 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { @@ -47,7 +48,8 @@ public void setUp() { // TODO: restore the test data here somehow: TestData$.MODULE$.testData(); SparkContext sc = new SparkContext("local[*]", "testing"); jsc = new JavaSparkContext(sc); - context = new SQLContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1cba0b769542..5c7b670c5566 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -29,6 +29,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach with SQLTestUtils { import ctx.implicits._ import ctx._ + ctx.loadTestData() + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a2296f881c31..5f6619b1f09d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -37,6 +37,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { import ctx.implicits._ import ctx._ + ctx.loadTestData() + test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index ab7c5e47cb77..0cc127aaa5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -29,6 +29,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { import ctx.implicits._ import ctx._ + ctx.loadTestData() + test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5be2d1e9e1a8..921e4507ca2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -33,6 +33,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { import ctx.planner._ import 
ctx._ + ctx.loadTestData() + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index e080c40813ff..ba68d985b2da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} /** * A collection of sample data used in SQL tests. */ -private[sql] trait TestSQLData { +private[sql] trait SQLTestData { protected val _sqlContext: SQLContext import _sqlContext.implicits._ @@ -228,6 +228,35 @@ private[sql] trait TestSQLData { df } + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } + /* ------------------------------ * | Case classes used in test data | * ------------------------------ */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 57368bdc24d5..36bce1f184db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils /** * General helper trait for common functionality in SQL tests. */ -private[spark] trait SQLTestUtils +private[sql] trait SQLTestUtils extends SparkFunSuite with AbstractSQLTestUtils with SharedSQLContext { @@ -42,7 +42,7 @@ private[spark] trait SQLTestUtils /** * Abstract helper trait for SQL tests with a pluggable [[SQLContext]]. */ -private[spark] trait AbstractSQLTestUtils { this: SparkFunSuite => +private[sql] trait AbstractSQLTestUtils { this: SparkFunSuite => protected def _sqlContext: SQLContext protected def configuration = _sqlContext.sparkContext.hadoopConfiguration diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 0473e4666544..dd2f8aacdfd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -39,6 +39,11 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll */ protected def sqlContext: TestSQLContext = _ctx + /** + * Initialize all test data such that all temp tables are properly registered. + */ + protected final def loadTestData(): Unit = _ctx.loadTestData() + /** * Switch to a custom [[TestSQLContext]]. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 7f179ea5b749..05cbc24ffc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext} /** * A special [[SQLContext]] prepared for testing. */ -private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with TestSQLData { +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with SQLTestData { def this() { this(new SparkContext("local[2]", "test-sql-context", @@ -34,7 +34,7 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with protected override val _sqlContext: SQLContext = this // Use fewer paritions to speed up testing - override protected[sql] def createSession(): SQLSession = new this.SQLSession() + protected[sql] override def createSession(): SQLSession = new this.SQLSession() /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ protected[sql] class SQLSession extends super.SQLSession { From eee415d1899162534753aab4774a69374d19cf5f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 15:18:16 -0700 Subject: [PATCH 09/39] Refactor implicits into SQLTestUtils This commit allows us to call `import testImplicits._` in the test constructor and use implicit methods properly. This was previously not possible without also starting a SQLContext in the constructor. Instead, now we can properly use implicits *while* starting the SQLContext in `beforeAll`. However, there is currently an issue with tests using the test data prepared in advance. This will be fixed in the subsequent commit. 
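For illustration only (this sketch is not part of the patch), a suite written against this refactoring is expected to look roughly as follows; the suite name is hypothetical and only helpers already visible in this series (QueryTest, SQLTestUtils, testImplicits, checkAnswer) are assumed:

    package org.apache.spark.sql

    import org.apache.spark.sql.test.SQLTestUtils

    // Hypothetical example, not included in this patch.
    class ExampleImplicitsSuite extends QueryTest with SQLTestUtils {
      // Importing implicits in the constructor no longer forces a SQLContext
      // to be created here; the shared context is still started in beforeAll.
      import testImplicits._

      test("toDF and $-columns come from testImplicits") {
        val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
        checkAnswer(df.filter($"key" === 1), Row(1, "a") :: Nil)
      }
    }
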
--- .../org/apache/spark/sql/SQLContext.scala | 95 +--- .../org/apache/spark/sql/SQLImplicits.scala | 122 +++++ .../apache/spark/sql/JavaDataFrameSuite.java | 1 - .../apache/spark/sql/CachedTableSuite.scala | 68 ++- .../spark/sql/ColumnExpressionSuite.scala | 20 +- .../spark/sql/DataFrameAggregateSuite.scala | 8 +- .../spark/sql/DataFrameFunctionsSuite.scala | 8 +- .../spark/sql/DataFrameImplicitsSuite.scala | 7 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 8 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 7 +- .../apache/spark/sql/DataFrameStatSuite.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../spark/sql/DataFrameTungstenSuite.scala | 3 +- .../apache/spark/sql/DateFunctionsSuite.scala | 7 +- .../org/apache/spark/sql/JoinSuite.scala | 6 +- .../apache/spark/sql/ListTablesSuite.scala | 7 +- .../spark/sql/MathExpressionsSuite.scala | 8 +- .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 5 +- .../apache/spark/sql/SQLContextSuite.scala | 5 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 442 +++++++++--------- .../sql/ScalaReflectionRelationSuite.scala | 7 +- .../spark/sql/StringFunctionsSuite.scala | 7 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 7 +- .../columnar/InMemoryColumnarQuerySuite.scala | 28 +- .../columnar/PartitionBatchPruningSuite.scala | 8 +- .../spark/sql/execution/AggregateSuite.scala | 1 - .../spark/sql/execution/PlannerSuite.scala | 39 +- .../execution/RowFormatConvertersSuite.scala | 1 - .../spark/sql/execution/SparkPlanTest.scala | 6 +- .../execution/SparkSqlSerializer2Suite.scala | 8 +- .../sql/execution/TungstenSortSuite.scala | 1 - .../sql/execution/debug/DebuggingSuite.scala | 9 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 82 ++-- .../spark/sql/jdbc/JDBCWriteSuite.scala | 9 +- .../org/apache/spark/sql/json/JsonSuite.scala | 122 +++-- .../sql/parquet/ParquetFilterSuite.scala | 3 +- .../spark/sql/parquet/ParquetIOSuite.scala | 3 +- .../ParquetPartitionDiscoverySuite.scala | 33 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 59 ++- .../spark/sql/sources/DataSourceTest.scala | 4 +- .../apache/spark/sql/test/SQLTestData.scala | 18 +- .../apache/spark/sql/test/SQLTestUtils.scala | 42 +- .../spark/sql/test/SharedSQLContext.scala | 38 +- .../spark/sql/test/TestSQLContext.scala | 16 +- .../spark/sql/hive/test/HiveTestUtils.scala | 12 +- .../sql/hive/test/SharedHiveContext.scala | 32 +- 48 files changed, 739 insertions(+), 705 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a0984654..f7d631d4bb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -43,7 +42,6 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -331,97 +329,10 @@ class SQLContext(@transient val 
sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args : _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self } + // scalastyle:on /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 000000000000..cabed5c3630b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + */ +private[sql] abstract class SQLImplicits { + protected def _sqlContext: SQLContext + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args : _*)) + } + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of case classes or tuples. + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. 
+ * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9be7af272515..656a2e1aa49a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -45,7 +45,6 @@ public class JavaDataFrameSuite { @Before public void setUp() { // Trigger static initializer of TestData - // TODO: restore the test data here somehow: TestData$.MODULE$.testData(); SparkContext sc = new SparkContext("local[*]", "testing"); jsc = new JavaSparkContext(sc); context = new TestSQLContext(sc); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 7bd5b9c73865..4330044e31ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,21 +18,19 @@ package org.apache.spark.sql import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class CachedTableSuite extends QueryTest with SQLTestUtils { + import testImplicits._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan @@ -50,9 +48,9 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("cache temp table") { testData.select('key).registerTempTable("tempTable") - assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable"), 0) ctx.cacheTable("tempTable") - assertCached(sql("SELECT COUNT(*) FROM tempTable")) + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable")) ctx.uncacheTable("tempTable") } @@ -71,8 +69,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("cache table as select") { - sql("CACHE TABLE tempTable AS SELECT key FROM testData") - assertCached(sql("SELECT COUNT(*) FROM tempTable")) + ctx.sql("CACHE TABLE tempTable AS SELECT key FROM testData") + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable")) ctx.uncacheTable("tempTable") } @@ -81,14 +79,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { testData.select('key).registerTempTable("tempTable2") ctx.cacheTable("tempTable1") - assertCached(sql("SELECT COUNT(*) FROM tempTable1")) - assertCached(sql("SELECT COUNT(*) FROM tempTable2")) + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable1")) + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? 
ctx.uncacheTable("tempTable2") // Should this be cached? - assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) + assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable1"), 0) } test("too big for memory") { @@ -169,26 +167,26 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("SELECT star from cached table") { - sql("SELECT * FROM testData").registerTempTable("selectStar") + ctx.sql("SELECT * FROM testData").registerTempTable("selectStar") ctx.cacheTable("selectStar") checkAnswer( - sql("SELECT * FROM selectStar WHERE key = 1"), + ctx.sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = - sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() + ctx.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() ctx.cacheTable("testData") checkAnswer( - sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + ctx.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { - sql("CACHE TABLE testData") + ctx.sql("CACHE TABLE testData") assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") @@ -196,7 +194,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { @@ -205,7 +203,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + ctx.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") @@ -220,7 +218,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE TABLE tableName AS SELECT ...") { - sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") + ctx.sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") @@ -235,7 +233,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE LAZY TABLE tableName") { - sql("CACHE LAZY TABLE testData") + ctx.sql("CACHE LAZY TABLE testData") assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") @@ -243,7 +241,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { !isMaterialized(rddId), "Lazily cached in-memory table shouldn't be materialized eagerly") - sql("SELECT COUNT(*) FROM testData").collect() + ctx.sql("SELECT COUNT(*) FROM testData").collect() assert( isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") @@ -255,7 +253,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("InMemoryRelation statistics") { - sql("CACHE TABLE testData") + ctx.sql("CACHE TABLE testData") ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum @@ -284,24 +282,24 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("Clear all cache") { - sql("SELECT key FROM testData LIMIT 
10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") ctx.cacheTable("t1") ctx.cacheTable("t2") ctx.clearCache() assert(ctx.cacheManager.isEmpty) - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") ctx.cacheTable("t1") ctx.cacheTable("t2") - sql("Clear CACHE") + ctx.sql("Clear CACHE") assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") Accumulators.synchronized { val accsSize = Accumulators.originals.size @@ -310,10 +308,10 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { assert((accsSize + 2) == Accumulators.originals.size) } - sql("SELECT * FROM t1").count() - sql("SELECT * FROM t2").count() - sql("SELECT * FROM t1").count() - sql("SELECT * FROM t2").count() + ctx.sql("SELECT * FROM t1").count() + ctx.sql("SELECT * FROM t2").count() + ctx.sql("SELECT * FROM t1").count() + ctx.sql("SELECT * FROM t2").count() Accumulators.synchronized { val accsSize = Accumulators.originals.size diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 695d7c284b18..df49af2f1996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -25,9 +25,16 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ + + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") @@ -356,13 +363,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) - test("&&") { checkAnswer( booleanData.filter($"a" && true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7238b8dc1ee8..29e0b0805c20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -18,14 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import 
org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class DataFrameAggregateSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("groupBy") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 903ca6fa9475..aa2dd9a7382f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. */ -class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class DataFrameFunctionsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 6205a73f03dd..b52a222c0a57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("RDD of tuples") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 216a24a533dd..49543ea5b481 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class DataFrameJoinSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class DataFrameJoinSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 97302bf88d8b..eb54287558ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import 
org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8ea9fa9f88ac..c02d547c65d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -22,11 +22,10 @@ import java.util.Random import org.scalatest.Matchers._ import org.apache.spark.sql.functions.col -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class DataFrameStatSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class DataFrameStatSuite extends QueryTest with SQLTestUtils { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7d744e11fa25..97783ee0de0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,9 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 3ae70c03e392..b2aacc04755b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.types._ * end-to-end test infra. In the long run this should just go away. 
*/ class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ + import testImplicits._ test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 586c3736556b..bc0f9ca33abd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,12 +22,11 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5c7b670c5566..a3d6b270e775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,11 +25,9 @@ import org.apache.spark.sql.test.SQLTestUtils class JoinSuite extends QueryTest with BeforeAndAfterEach with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ - ctx.loadTestData() + loadTestData() test("equi-join is hash-join") { val x = testData2.as("x") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 87e37837c6e3..d73f2f4e02dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SQLTestUtils { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b05c0def05d5..617ce83d7ff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest with 
SharedSQLContext { +class MathExpressionsSuite extends QueryTest with SQLTestUtils { import MathExpressionsTestData._ - - private val ctx = sqlContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index c30f44e807e2..b22944e37ef4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 508105a0d056..bcb6bd887bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class SQLConfSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext +class SQLConfSuite extends QueryTest with SQLTestUtils { private val testKey = "test.key.0" private val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 038e9b069cab..23fe33d61f80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class SQLContextSuite extends SparkFunSuite with SharedSQLContext { - private val ctx = sqlContext +class SQLContextSuite extends SparkFunSuite with SQLTestUtils { override def afterAll(): Unit = { SQLContext.setLastInstantiatedContext(ctx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5f6619b1f09d..923bceda0981 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -33,35 +33,34 @@ import org.apache.spark.sql.types._ class MyDialect extends DefaultParserDialect class SQLQuerySuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ - ctx.loadTestData() + loadTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") checkAnswer( - sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), + 
ctx.sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), Row("one", 6) :: Row("three", 3) :: Nil) } test("SPARK-8010: promote numeric to string") { val df = Seq((1, 1)).toDF("key", "value") df.registerTempTable("src") - val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") - val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") + val queryCaseWhen = ctx.sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = ctx.sql("select coalesce(null, 1, '1') from src ") checkAnswer(queryCaseWhen, Row("1.0") :: Nil) checkAnswer(queryCoalesce, Row("1") :: Nil) } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(ctx.sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { - checkExistence(sql("describe function extended upper"), true, + checkExistence(ctx.sql("describe function extended upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase", @@ -69,15 +68,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { "> SELECT upper('SparkSql');", "'SPARKSQL'") - checkExistence(sql("describe functioN Upper"), true, + checkExistence(ctx.sql("describe functioN Upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase") - checkExistence(sql("describe functioN Upper"), false, + checkExistence(ctx.sql("describe functioN Upper"), false, "Extended Usage") - checkExistence(sql("describe functioN abcadf"), true, + checkExistence(ctx.sql("describe functioN abcadf"), true, "Function: abcadf is not found.") } @@ -90,7 +89,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { sqlContext.cacheTable("cachedData") checkAnswer( - sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + ctx.sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) } @@ -98,7 +97,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( - sql( + ctx.sql( """ |SELECT x.str, COUNT(*) |FROM df x JOIN df y ON x.str = y.str @@ -109,7 +108,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("support table.star") { checkAnswer( - sql( + ctx.sql( """ |SELECT r.* |FROM testData l join testData2 r on (l.key = r.a) @@ -126,7 +125,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .registerTempTable("df") checkAnswer( - sql( + ctx.sql( """ |SELECT x.str, SUM(x.strCount) |FROM df x JOIN df y ON x.str = y.str @@ -165,7 +164,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( - sql("SELECT a FROM testData2 SORT BY a"), + ctx.sql("SELECT a FROM testData2 SORT BY a"), Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } @@ -201,7 +200,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .registerTempTable("rows") checkAnswer( - sql( + ctx.sql( """ |select attribute, sum(cnt) |from ( @@ -220,7 +219,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .registerTempTable("d") checkAnswer( - sql("select * from d where d.a in (1,2)"), + ctx.sql("select * from d where d.a in (1,2)"), Seq(Row("1"), Row("2"))) } @@ -228,13 +227,13 @@ class SQLQuerySuite extends 
QueryTest with SQLTestUtils { withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( - sql("select sum(a), avg(a) from allNulls"), + ctx.sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { checkAnswer( - sql("select sum(a), avg(a) from allNulls"), + ctx.sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } @@ -242,13 +241,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( - sql("select sum(a), avg(a) from allNulls"), + ctx.sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { checkAnswer( - sql("select sum(a), avg(a) from allNulls"), + ctx.sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } @@ -256,7 +255,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) + val df = ctx.sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } @@ -358,82 +357,82 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("Add Parser of SQL COALESCE()") { checkAnswer( - sql("""SELECT COALESCE(1, 2)"""), + ctx.sql("""SELECT COALESCE(1, 2)"""), Row(1)) checkAnswer( - sql("SELECT COALESCE(null, 1, 1.5)"), + ctx.sql("SELECT COALESCE(null, 1, 1.5)"), Row(BigDecimal(1))) checkAnswer( - sql("SELECT COALESCE(null, null, null)"), + ctx.sql("SELECT COALESCE(null, null, null)"), Row(null)) } test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( - sql("SELECT LAST(n) FROM lowerCaseData"), + ctx.sql("SELECT LAST(n) FROM lowerCaseData"), Row(4)) } test("SPARK-2041 column name equals tablename") { checkAnswer( - sql("SELECT tableName FROM tableName"), + ctx.sql("SELECT tableName FROM tableName"), Row("test")) } test("SQRT") { checkAnswer( - sql("SELECT SQRT(key) FROM testData"), + ctx.sql("SELECT SQRT(key) FROM testData"), (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } test("SQRT with automatic string casts") { checkAnswer( - sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), + ctx.sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( - sql("SELECT substr(tableName, 1, 2) FROM tableName"), + ctx.sql("SELECT substr(tableName, 1, 2) FROM tableName"), Row("te")) checkAnswer( - sql("SELECT substr(tableName, 3) FROM tableName"), + ctx.sql("SELECT substr(tableName, 3) FROM tableName"), Row("st")) checkAnswer( - sql("SELECT substring(tableName, 1, 2) FROM tableName"), + ctx.sql("SELECT substring(tableName, 1, 2) FROM tableName"), Row("te")) checkAnswer( - sql("SELECT substring(tableName, 3) FROM tableName"), + ctx.sql("SELECT substring(tableName, 3) FROM tableName"), Row("st")) } test("SPARK-3173 Timestamp support in the parser") { (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") - checkAnswer(sql( + checkAnswer(ctx.sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) - checkAnswer(sql( + checkAnswer(ctx.sql( "SELECT time FROM timestamps WHERE 
time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(sql( + checkAnswer(ctx.sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(sql( + checkAnswer(ctx.sql( "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(sql( + checkAnswer(ctx.sql( """SELECT time FROM timestamps WHERE time<'1969-12-31 16:00:00.003' AND time>'1969-12-31 16:00:00.001'"""), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT time FROM timestamps |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') @@ -441,39 +440,41 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) - checkAnswer(sql( + checkAnswer(ctx.sql( "SELECT time FROM timestamps WHERE time='123'"), Nil) } test("index into array") { checkAnswer( - sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), + ctx.sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) } test("left semi greater than predicate") { checkAnswer( - sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), + ctx.sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), Seq(Row(3, 1), Row(3, 2)) ) } test("left semi greater than predicate and equal operator") { checkAnswer( - sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + ctx.sql( + "SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), Seq(Row(3, 1), Row(3, 2)) ) checkAnswer( - sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + ctx.sql( + "SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) ) } test("index into array of arrays") { checkAnswer( - sql( + ctx.sql( "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), arrayData.map(d => Row(d.nestedData, @@ -483,28 +484,28 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("agg") { checkAnswer( - sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), + ctx.sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { def literalInAggTest(): Unit = { checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + ctx.sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + ctx.sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + 
ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) + ctx.sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + ctx.sql("SELECT 1, 2, sum(b) FROM testData2")) } literalInAggTest() @@ -515,62 +516,62 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), + ctx.sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), Row(1, 3, 2, 6, 3) ) } test("select *") { checkAnswer( - sql("SELECT * FROM testData"), + ctx.sql("SELECT * FROM testData"), testData.collect().toSeq) } test("simple select") { checkAnswer( - sql("SELECT value FROM testData WHERE key = 1"), + ctx.sql("SELECT value FROM testData WHERE key = 1"), Row("1")) } def sortTest(): Unit = { checkAnswer( - sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), + ctx.sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( - sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), + ctx.sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), + ctx.sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( - sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), + ctx.sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( - sql("SELECT b FROM binaryData ORDER BY a ASC"), + ctx.sql("SELECT b FROM binaryData ORDER BY a ASC"), (1 to 5).map(Row(_))) checkAnswer( - sql("SELECT b FROM binaryData ORDER BY a DESC"), + ctx.sql("SELECT b FROM binaryData ORDER BY a DESC"), (1 to 5).map(Row(_)).toSeq.reverse) checkAnswer( - sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), + ctx.sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq) checkAnswer( - sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), + ctx.sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq) checkAnswer( - sql("SELECT * FROM mapData ORDER BY data[1] ASC"), + ctx.sql("SELECT * FROM mapData ORDER BY data[1] ASC"), mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq) checkAnswer( - sql("SELECT * FROM mapData ORDER BY data[1] DESC"), + ctx.sql("SELECT * FROM mapData ORDER BY data[1] DESC"), mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } @@ -602,25 +603,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("limit") { checkAnswer( - sql("SELECT * FROM testData LIMIT 10"), + ctx.sql("SELECT * FROM testData LIMIT 10"), testData.take(10).toSeq) checkAnswer( - sql("SELECT * FROM arrayData LIMIT 1"), + ctx.sql("SELECT * FROM arrayData LIMIT 1"), arrayData.collect().take(1).map(Row.fromTuple).toSeq) checkAnswer( - sql("SELECT * FROM mapData LIMIT 1"), + ctx.sql("SELECT * FROM mapData LIMIT 1"), mapData.collect().take(1).map(Row.fromTuple).toSeq) } test("CTE feature") { checkAnswer( - sql("with q1 as (select * from testData limit 10) select * from q1"), + ctx.sql("with q1 as (select * from testData limit 10) select * from q1"), testData.take(10).toSeq) checkAnswer( - sql(""" + 
ctx.sql(""" |with q1 as (select * from testData where key= '5'), |q2 as (select * from testData where key = '4') |select * from q1 union all select * from q2""".stripMargin), @@ -630,20 +631,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql( + ctx.sql( "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } test("date row") { - checkAnswer(sql( + checkAnswer(ctx.sql( """select cast("2015-01-28" as date) from testData limit 1"""), Row(java.sql.Date.valueOf("2015-01-28")) ) } test("from follow multiple brackets") { - checkAnswer(sql( + checkAnswer(ctx.sql( """ |select key from ((select * from testData limit 1) | union all (select * from testData limit 1)) x limit 1 @@ -651,12 +652,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(1) ) - checkAnswer(sql( + checkAnswer(ctx.sql( "select key from (select * from testData) x limit 1"), Row(1) ) - checkAnswer(sql( + checkAnswer(ctx.sql( """ |select key from | (select * from testData limit 1 union all select * from testData limit 1) x @@ -668,50 +669,48 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("average") { checkAnswer( - sql("SELECT AVG(a) FROM testData2"), + ctx.sql("SELECT AVG(a) FROM testData2"), Row(2.0)) } test("average overflow") { checkAnswer( - sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), + ctx.sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { checkAnswer( - sql("SELECT COUNT(*) FROM testData2"), + ctx.sql("SELECT COUNT(*) FROM testData2"), Row(testData2.count())) } test("count distinct") { checkAnswer( - sql("SELECT COUNT(DISTINCT b) FROM testData2"), + ctx.sql("SELECT COUNT(DISTINCT b) FROM testData2"), Row(2)) } test("approximate count distinct") { checkAnswer( - sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), + ctx.sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( - sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), + ctx.sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), Row(3)) } test("null count") { checkAnswer( - sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), + ctx.sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + ctx.sql( + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -719,14 +718,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { withTempTable("t") { Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") checkAnswer( - sql("select count(a) from t"), + ctx.sql("select count(a) from t"), Row(0)) } } test("inner join where, one match per row") { checkAnswer( - sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), + ctx.sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -736,7 +735,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("inner join ON, one match per row") { checkAnswer( - sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), + ctx.sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -746,7 +745,7 
@@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("inner join, where, multiple matches") { checkAnswer( - sql(""" + ctx.sql(""" |SELECT * FROM | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y @@ -759,7 +758,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("inner join, no matches") { checkAnswer( - sql( + ctx.sql( """ |SELECT * FROM | (SELECT * FROM testData2 WHERE a = 1) x JOIN @@ -770,7 +769,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("big inner join, 4 matches per row") { checkAnswer( - sql( + ctx.sql( """ |SELECT * FROM | (SELECT * FROM testData UNION ALL @@ -797,7 +796,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("left outer join") { checkAnswer( - sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), + ctx.sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -808,7 +807,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("right outer join") { checkAnswer( - sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), + ctx.sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -819,7 +818,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("full outer join") { checkAnswer( - sql( + ctx.sql( """ |SELECT * FROM | (SELECT * FROM upperCaseData WHERE N <= 4) leftTable FULL OUTER JOIN @@ -835,25 +834,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-3349 partitioning after limit") { - sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") + ctx.sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - sql("SELECT DISTINCT n FROM lowerCaseData") + ctx.sql("SELECT DISTINCT n FROM lowerCaseData") .limit(2) .registerTempTable("subset2") checkAnswer( - sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), + ctx.sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), Row(3, "c", 3) :: Row(4, "d", 4) :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), + ctx.sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), Row(1, "a", 1) :: Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { checkAnswer( - sql( + ctx.sql( """ |SeleCT * from | (select * from upperCaseData WherE N <= 4) leftTable fuLL OUtER joiN @@ -870,13 +869,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("select with table name as qualifier") { checkAnswer( - sql("SELECT testData.value FROM testData WHERE testData.key = 1"), + ctx.sql("SELECT testData.value FROM testData WHERE testData.key = 1"), Row("1")) } test("inner join ON with table name as qualifier") { checkAnswer( - sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), + ctx.sql( + "SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -886,7 +886,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("qualified select with inner join ON with table name as qualifier") { checkAnswer( - sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + + ctx.sql("SELECT upperCaseData.N, upperCaseData.L FROM 
upperCaseData JOIN lowerCaseData " + "ON lowerCaseData.n = upperCaseData.N"), Seq( Row(1, "A"), @@ -897,7 +897,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("system function upper()") { checkAnswer( - sql("SELECT n,UPPER(l) FROM lowerCaseData"), + ctx.sql("SELECT n,UPPER(l) FROM lowerCaseData"), Seq( Row(1, "A"), Row(2, "B"), @@ -905,7 +905,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(4, "D"))) checkAnswer( - sql("SELECT n, UPPER(s) FROM nullStrings"), + ctx.sql("SELECT n, UPPER(s) FROM nullStrings"), Seq( Row(1, "ABC"), Row(2, "ABC"), @@ -914,7 +914,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("system function lower()") { checkAnswer( - sql("SELECT N,LOWER(L) FROM upperCaseData"), + ctx.sql("SELECT N,LOWER(L) FROM upperCaseData"), Seq( Row(1, "a"), Row(2, "b"), @@ -924,7 +924,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(6, "f"))) checkAnswer( - sql("SELECT n, LOWER(s) FROM nullStrings"), + ctx.sql("SELECT n, LOWER(s) FROM nullStrings"), Seq( Row(1, "abc"), Row(2, "abc"), @@ -933,14 +933,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("UNION") { checkAnswer( - sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), + ctx.sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), + ctx.sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), + ctx.sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: Row(4, "d") :: Row(4, "d") :: Nil) } @@ -948,63 +948,63 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("UNION with column mismatches") { // Column name mismatches are allowed. checkAnswer( - sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), + ctx.sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( - sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), + ctx.sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. 
intercept[AnalysisException] { - sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() + ctx.sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() } } test("EXCEPT") { checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), + ctx.sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) + ctx.sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( - sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) + ctx.sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } test("INTERSECT") { checkAnswer( - sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), + ctx.sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) + ctx.sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } - test("SET commands semantics using sql()") { + test("SET commands semantics using ctx.sql()") { sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. - assert(sql("SET").collect().size == 0) + assert(ctx.sql("SET").collect().size == 0) // "set key=val" - sql(s"SET $testKey=$testVal") + ctx.sql(s"SET $testKey=$testVal") checkAnswer( - sql("SET"), + ctx.sql("SET"), Row(testKey, testVal) ) - sql(s"SET ${testKey + testKey}=${testVal + testVal}") + ctx.sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( - sql("set"), + ctx.sql("set"), Seq( Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) @@ -1012,11 +1012,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { // "set key" checkAnswer( - sql(s"SET $testKey"), + ctx.sql(s"SET $testKey"), Row(testKey, testVal) ) checkAnswer( - sql(s"SET $nonexistentKey"), + ctx.sql(s"SET $nonexistentKey"), Row(nonexistentKey, "") ) sqlContext.conf.clear() @@ -1026,9 +1026,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported - intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) - intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) - intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) + intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-1")) + intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-01")) + intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-2")) sqlContext.conf.clear() } @@ -1050,14 +1050,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( - sql("SELECT * FROM applySchema1"), + ctx.sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: Row(2, "B2", false, null) :: Row(3, "C3", true, null) :: Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( - sql("SELECT f1, f4 FROM applySchema1"), + ctx.sql("SELECT f1, f4 FROM applySchema1"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1080,14 +1080,14 @@ 
class SQLQuerySuite extends QueryTest with SQLTestUtils { val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( - sql("SELECT * FROM applySchema2"), + ctx.sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: Row(Row(2, false), Map("B2" -> null)) :: Row(Row(3, true), Map("C3" -> null)) :: Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( - sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + ctx.sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1106,7 +1106,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { df3.registerTempTable("applySchema3") checkAnswer( - sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + ctx.sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1115,17 +1115,17 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-3423 BETWEEN") { checkAnswer( - sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), + ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), Seq(Row(5, "5"), Row(6, "6"), Row(7, "7")) ) checkAnswer( - sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), + ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), Row(7, "7") ) checkAnswer( - sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), + ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), Nil ) } @@ -1133,12 +1133,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("cast boolean to string") { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( - sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), + ctx.sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), Row("true", "false")) } test("metadata is propagated correctly") { - val person: DataFrame = sql("SELECT * FROM person") + val person: DataFrame = ctx.sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -1155,39 +1155,41 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"id", $"name")) - validateMetadata(sql("SELECT * FROM personWithMeta")) - validateMetadata(sql("SELECT id, name FROM personWithMeta")) - validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(ctx.sql("SELECT * FROM personWithMeta")) + validateMetadata(ctx.sql("SELECT id, name FROM personWithMeta")) + validateMetadata(ctx.sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(ctx.sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + ctx.sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) } test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { checkAnswer( - sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), + ctx.sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 
END FROM testData WHERE key = 1 group by key"), Row(1)) } test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), + ctx.sql( + "SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), Row(1)) } test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { if (isInvalidQuery) { - val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) + val e = intercept[AnalysisException](ctx.sql(query).queryExecution.analyzed) assert(e.getMessage contains "group by") } else { // Should not throw - sql(query).queryExecution.analyzed + ctx.sql(query).queryExecution.analyzed } } @@ -1203,137 +1205,137 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("Test to check we can use Long.MinValue") { checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) + ctx.sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) ) checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), + ctx.sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq ) } test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) + ctx.sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) + ctx.sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(BigDecimal(0.5)) + ctx.sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(BigDecimal(-0.18)) + ctx.sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } test("Auto cast integer type") { checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) + ctx.sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) ) checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) + ctx.sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) ) checkAnswer( - sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) + ctx.sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) + ctx.sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } test("Test to check we can apply sign to expression") { checkAnswer( - sql("SELECT -100"), Row(-100) + ctx.sql("SELECT -100"), Row(-100) ) checkAnswer( - sql("SELECT +230"), Row(230) + ctx.sql("SELECT +230"), Row(230) ) checkAnswer( - sql("SELECT -5.2"), Row(BigDecimal(-5.2)) + ctx.sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(BigDecimal(6.8)) + ctx.sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) + ctx.sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) ) checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), Row(3) + ctx.sql("SELECT +key FROM testData WHERE key = 3"), Row(3) ) checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) + ctx.sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) ) checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) + ctx.sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) ) 
checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) + ctx.sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) ) checkAnswer( - sql("SELECT -MAX(key) FROM testData"), Row(-100) + ctx.sql("SELECT -MAX(key) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT +MAX(key) FROM testData"), Row(100) + ctx.sql("SELECT +MAX(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT - (-10)"), Row(10) + ctx.sql("SELECT - (-10)"), Row(10) ) checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) + ctx.sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) ) checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), Row(-100) + ctx.sql("SELECT - (+Max(key)) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT - - 3"), Row(3) + ctx.sql("SELECT - - 3"), Row(3) ) checkAnswer( - sql("SELECT - + 20"), Row(-20) + ctx.sql("SELECT - + 20"), Row(-20) ) checkAnswer( - sql("SELEcT - + 45"), Row(-45) + ctx.sql("SELEcT - + 45"), Row(-45) ) checkAnswer( - sql("SELECT + + 100"), Row(100) + ctx.sql("SELECT + + 100"), Row(100) ) checkAnswer( - sql("SELECT - - Max(key) FROM testData"), Row(100) + ctx.sql("SELECT - - Max(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) + ctx.sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) ) } test("Multiple join") { checkAnswer( - sql( + ctx.sql( """SELECT a.key, b.key, c.key |FROM testData a |JOIN testData b ON a.key = b.key @@ -1346,28 +1348,28 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") - sql("SELECT `key?number1`, `key.number2` FROM records") + ctx.sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { - checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(ctx.sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise | operator") { - checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(ctx.sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ^ operator") { - checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(ctx.sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ~ operator") { - checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) + checkAnswer(ctx.sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) } test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { checkAnswer( - sql( + ctx.sql( """SELECT a.key, b.key, c.key |FROM testData a,testData b,testData c |where a.key = b.key and a.key = c.key @@ -1376,37 +1378,37 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), + checkAnswer(ctx.sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { - checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + checkAnswer(ctx.sql("SELECT key FROM testData WHERE value not like '100%' order by key"), (1 to 99).map(i => Row(i))) } 
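
Every suite touched by this patch follows the same migration: the per-suite `private val ctx = sqlContext`, `import ctx.implicits._` and `import ctx._` boilerplate is dropped in favour of mixing in `SQLTestUtils`, importing `testImplicits._`, and qualifying every bare `sql(...)` call as `ctx.sql(...)`. A minimal sketch of a suite after the migration, assuming `SQLTestUtils` exposes `ctx` and `testImplicits` as the hunks above show (the suite name, table name and data are illustrative only):

package org.apache.spark.sql

import org.apache.spark.sql.test.SQLTestUtils

class ExampleMigratedSuite extends QueryTest with SQLTestUtils {
  // replaces the old `private val ctx = sqlContext` + `import ctx.implicits._`
  import testImplicits._

  test("bare sql() calls are routed through ctx") {
    // toDF comes from testImplicits; the temp table is illustrative only
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("example")
    checkAnswer(
      ctx.sql("SELECT value FROM example WHERE key = 1"),
      Row("a") :: Nil)
  }
}
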
test("SPARK-4322 Grouping field with struct field as sub expression") { sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") - checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) + checkAnswer(ctx.sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) + checkAnswer(ctx.sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( - sql("SELECT a + b FROM testData2 ORDER BY a"), + ctx.sql("SELECT a + b FROM testData2 ORDER BY a"), Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( - sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), + ctx.sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } @@ -1418,7 +1420,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") - checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + + checkAnswer(ctx.sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) } @@ -1427,7 +1429,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") - checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) + checkAnswer(ctx.sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { @@ -1435,7 +1437,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") - checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) + checkAnswer(ctx.sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } @@ -1444,19 +1446,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") - checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) - checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) - checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) - checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) - checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) - checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(ctx.sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(ctx.sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + 
checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(ctx.sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) } test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(ctx.sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(ctx.sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { @@ -1464,7 +1466,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") - checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + checkAnswer(ctx.sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } test("SPARK-6583 order by aggregated function") { @@ -1472,7 +1474,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .toDF("a", "b").registerTempTable("orderByData") checkAnswer( - sql( + ctx.sql( """ |SELECT a |FROM orderByData @@ -1482,7 +1484,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT sum(b) |FROM orderByData @@ -1492,7 +1494,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT a, sum(b) |FROM orderByData @@ -1502,7 +1504,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT a, sum(b) |FROM orderByData @@ -1527,8 +1529,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { (null, null, null, true) ).toDF("i", "b", "r1", "r2").registerTempTable("t") - checkAnswer(sql("select i = b from t"), sql("select r1 from t")) - checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + checkAnswer(ctx.sql("select i = b from t"), ctx.sql("select r1 from t")) + checkAnswer(ctx.sql("select i <=> b from t"), ctx.sql("select r2 from t")) } } @@ -1536,14 +1538,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { withTempTable("t") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) + checkAnswer(ctx.sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } test("SPARK-8782: ORDER BY NULL") { withTempTable("t") { Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") - checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + checkAnswer(ctx.sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) } } @@ -1552,14 +1554,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val df = Seq(1 -> "a").toDF("count", "sort") checkAnswer(df.filter("count > 0"), Row(1, "a")) df.registerTempTable("t") - checkAnswer(sql("select count, sort from t"), Row(1, "a")) + checkAnswer(ctx.sql("select count, sort from t"), Row(1, "a")) } } test("SPARK-8753: add interval type") { import org.apache.spark.unsafe.types.CalendarInterval - val df = sql("select interval 3 years -3 month 7 week 123 microseconds") + val df = 
ctx.sql("select interval 3 years -3 month 7 week 123 microseconds") checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. @@ -1571,7 +1573,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { def checkIntervalParseError(s: String): Unit = { val e = intercept[AnalysisException] { - sql(s) + ctx.sql(s) } e.message.contains("at least one time unit should be given for interval literal") } @@ -1585,7 +1587,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK - val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + val df = ctx.sql("select interval 3 years -3 month 7 week 123 microseconds as i") checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), @@ -1626,7 +1628,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .toDF("num", "str") df.registerTempTable("1one") - checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + checkAnswer(ctx.sql("select count(num) from 1one"), Row(10)) sqlContext.dropTempTable("1one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index e117ca4c511c..e8b8224343b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils case class ReflectData( stringField: String, @@ -72,9 +72,8 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index f8a89f323047..01430eccecb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 2da7874d9d45..da7ff4d49823 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,9 +22,7 @@ import org.apache.spark.sql.test.SQLTestUtils private case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index ddf3c184839b..f3237edcca8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,7 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -67,9 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SQLTestUtils { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 0cc127aaa5f3..1befda35ace3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -25,11 +25,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ - ctx.loadTestData() + loadTestData() test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan @@ -67,25 +65,25 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-1678 regression: compression must not lose repeated values") { checkAnswer( - sql("SELECT * FROM repeatedData"), + ctx.sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("repeatedData") checkAnswer( - sql("SELECT * FROM repeatedData"), + ctx.sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) } test("with null values") { checkAnswer( - sql("SELECT * FROM nullableRepeatedData"), + ctx.sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("nullableRepeatedData") checkAnswer( - sql("SELECT * FROM nullableRepeatedData"), + ctx.sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) } @@ -94,25 +92,25 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { 
timestamps.registerTempTable("timestamps") checkAnswer( - sql("SELECT time FROM timestamps"), + ctx.sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) ctx.cacheTable("timestamps") checkAnswer( - sql("SELECT time FROM timestamps"), + ctx.sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { checkAnswer( - sql("SELECT * FROM withEmptyParts"), + ctx.sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("withEmptyParts") checkAnswer( - sql("SELECT * FROM withEmptyParts"), + ctx.sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) } @@ -134,7 +132,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { df.cache().registerTempTable("test_fixed_decimal") checkAnswer( - sql("SELECT * FROM test_fixed_decimal"), + ctx.sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } @@ -180,7 +178,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { } ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. - sql("cache table InMemoryCache_different_data_types") + ctx.sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( @@ -188,7 +186,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( - sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + ctx.sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), ctx.table("InMemoryCache_different_data_types").collect()) ctx.dropTempTable("InMemoryCache_different_data_types") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 591e1ff4a789..354a7fe98da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,12 +21,10 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { + import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala index bd022ec26111..48a907d4b1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions._ 
class AggregateSuite extends SparkPlanTest { - private val ctx = sqlContext test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { val codegenDefault = ctx.getConf(SQLConf.CODEGEN_ENABLED) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 921e4507ca2a..1cec8f676823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -28,14 +28,13 @@ import org.apache.spark.sql.{execution, Row, SQLConf} class PlannerSuite extends SparkFunSuite with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx.planner._ - import ctx._ + import testImplicits._ - ctx.loadTestData() + loadTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -50,6 +49,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -77,16 +78,16 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = ctx.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") + ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") - val planned = sql( + val planned = ctx.sql( """ |SELECT l.a, l.b |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) @@ -98,10 +99,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - dropTempTable("testLimit") + ctx.dropTempTable("testLimit") } - val origThreshold = conf.autoBroadcastJoinThreshold + val origThreshold = ctx.conf.autoBroadcastJoinThreshold val simpleTypes = NullType :: @@ -133,18 +134,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + val origThreshold = ctx.conf.autoBroadcastJoinThreshold + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + ctx.sql("CACHE TABLE tiny") val a = testData.as("a") - val b = table("tiny").as("b") + val b = ctx.table("tiny").as("b") val planned = 
a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } @@ -153,12 +154,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("efficient limit -> project -> sort") { val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = planner.TakeOrderedAndProject(query) + val planned = ctx.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } @@ -171,7 +172,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { // Disable broadcast join withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { { - val numExchanges = sql( + val numExchanges = ctx.sql( """ |SELECT * |FROM @@ -186,7 +187,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { { // This second query joins on different keys: - val numExchanges = sql( + val numExchanges = ctx.sql( """ |SELECT * |FROM diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index dc2c7b6fa185..9a2deb18a8e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} class RowFormatConvertersSuite extends SparkPlanTest { - private val ctx = sqlContext private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 94f56b456147..29cb920e6c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -25,15 +25,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with SharedSQLContext { - protected override def _sqlContext: SQLContext = sqlContext -} +private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with SQLTestUtils /** * Helper class for testing individual physical operators with a pluggable [[SQLContext]]. 
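
The hunks above all apply the same mechanical refactoring: each suite drops the local `private val ctx = sqlContext` (and the `SharedSQLContext` mixin where present), mixes in `SQLTestUtils` instead, pulls implicits from `testImplicits`, and routes every bare `sql(...)`, `table(...)`, `setConf(...)` and `dropTempTable(...)` call through the shared `ctx`. The sketch below shows the suite shape the patch converges on; it is illustrative only, assuming `SQLTestUtils` exposes `ctx`, `testImplicits`, `loadTestData()` and `withTempTable` as the diffs suggest, and `ExampleQuerySuite` with its toy data is hypothetical.

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical suite illustrating the post-refactoring shape; `ctx`,
// `testImplicits`, `loadTestData()` and `withTempTable` are assumed to be
// provided by SQLTestUtils, as the hunks above imply.
class ExampleQuerySuite extends QueryTest with SQLTestUtils {
  import testImplicits._    // was: private val ctx = sqlContext; import ctx.implicits._

  loadTestData()            // registers the shared test tables through the trait

  test("queries go through the shared context explicitly") {
    withTempTable("people") {
      Seq(("fred", 1), ("joe", 2)).toDF("name", "id").registerTempTable("people")
      checkAnswer(ctx.sql("SELECT name FROM people WHERE id = 1"), Row("fred"))
    }
  }
}
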
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 8fe68458e824..3535f8f184e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -22,11 +22,11 @@ import java.sql.{Timestamp, Date} import org.apache.spark.serializer.Serializer import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with SharedSQLContext { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with SQLTestUtils { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = @@ -65,9 +65,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite with SharedSQLConte checkSupported(new MyDenseVectorUDT, isSupported = false) } -abstract class SparkSqlSerializer2Suite extends QueryTest with SharedSQLContext { - protected val ctx = sqlContext - +abstract class SparkSqlSerializer2Suite extends QueryTest with SQLTestUtils { var allColumns: String = _ val serializerClass: Class[Serializer] = classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 8ce75e8f2318..9c5a03da4153 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.types._ * A test suite that generates randomized data to test the [[TungstenSort]] operator. 
*/ class TungstenSortSuite extends SparkPlanTest { - private val ctx = sqlContext override def beforeAll(): Unit = { ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index b242ff49542a..306e4c8b378c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,16 +18,15 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class DebuggingSuite extends SparkFunSuite with SharedSQLContext { - private val ctx = sqlContext +class DebuggingSuite extends SparkFunSuite with SQLTestUtils { test("DataFrame.debug()") { - ctx.testData.debug() + testData.debug() } test("DataFrame.typeCheck()") { - ctx.testData.typeCheck() + testData.typeCheck() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 04c3606d648f..d1120b2c910d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,14 +25,12 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { + import testImplicits._ val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" @@ -66,14 +64,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE fetchtwo |USING org.apache.spark.sql.jdbc @@ -81,7 +79,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc @@ -96,7 +94,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext conn.prepareStatement("insert into test.inttypes values (null, null, null, null, null)" ).executeUpdate() conn.commit() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc @@ -113,7 +111,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext stmt.setBytes(5, testBytes) stmt.setString(6, "I am a clob!") stmt.executeUpdate() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc @@ -127,7 +125,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext conn.prepareStatement("insert into test.timetypes values 
('12:34:56', " + "null, '2002-02-20 11:22:33.543543543')").executeUpdate() conn.commit() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc @@ -142,7 +140,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext + "1.00000011920928955078125, " + "123456789012345.543215432154321)").executeUpdate() conn.commit() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc @@ -159,7 +157,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext + "null, null, null, null, null, null, null, null, null, " + "null, null, null, null, null, null)").executeUpdate() conn.commit() - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE nulltypes |USING org.apache.spark.sql.jdbc @@ -174,24 +172,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT *") { - assert(sql("SELECT * FROM foobar").collect().size === 3) + assert(ctx.sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(ctx.sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(ctx.sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(ctx.sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(ctx.sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(ctx.sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(ctx.sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) } test("SELECT * WHERE (quoted strings)") { - assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) + assert(ctx.sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { - val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = ctx.sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) @@ -199,7 +197,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT first field when fetchSize is two") { - val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = ctx.sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) @@ -207,7 +205,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT second field") { - val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = ctx.sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -215,7 +213,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } 
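
// For reference, a minimal sketch of the register-then-query pattern these JDBC
// hunks keep qualifying with `ctx`. It assumes it lives inside JDBCSuite, so the
// H2 `url`, the shared `ctx`, and the TEST.PEOPLE fixture (three rows) created in
// the setup above are in scope; the temp table name `people_view` is hypothetical.
test("register TEST.PEOPLE and query it through ctx") {
  ctx.sql(
    s"""
       |CREATE TEMPORARY TABLE people_view
       |USING org.apache.spark.sql.jdbc
       |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
     """.stripMargin.replaceAll("\n", " "))

  assert(ctx.sql("SELECT NAME FROM people_view").collect().length === 3)
}
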
test("SELECT second field when fetchSize is two") { - val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = ctx.sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -223,17 +221,17 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * partitioned") { - assert(sql("SELECT * FROM parts").collect().size == 3) + assert(ctx.sql("SELECT * FROM parts").collect().size == 3) } test("SELECT WHERE (simple predicates) partitioned") { - assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) + assert(ctx.sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(ctx.sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(ctx.sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { - val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = ctx.sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -242,7 +240,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Register JDBC query with renamed fields") { // Regression test for bug SPARK-7345 - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE renamed |USING org.apache.spark.sql.jdbc @@ -250,7 +248,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext |user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - val df = sql("SELECT * FROM renamed") + val df = ctx.sql("SELECT * FROM renamed") assert(df.schema.fields.size == 2) assert(df.schema.fields(0).name == "NAME1") assert(df.schema.fields(1).name == "NAME2") @@ -281,7 +279,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 integral types") { - val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() + val rows = ctx.sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) assert(rows(0).getInt(0) === 1) assert(rows(0).getBoolean(1) === false) @@ -291,7 +289,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 null entries") { - val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() + val rows = ctx.sql("SELECT * FROM inttypes WHERE A IS NULL").collect() assert(rows.length === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) @@ -301,7 +299,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 string types") { - val rows = sql("SELECT * FROM strtypes").collect() + val rows = ctx.sql("SELECT * FROM strtypes").collect() assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes)) assert(rows(0).getString(1).equals("Sensitive")) assert(rows(0).getString(2).equals("Insensitive")) @@ -311,7 +309,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 time types") { - val rows = sql("SELECT * FROM timetypes").collect() + val rows = ctx.sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) assert(cal.get(Calendar.HOUR_OF_DAY) === 
12) @@ -345,7 +343,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") - val cachedRows = sql("select * from mycached_date").collect() + val cachedRows = ctx.sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } @@ -357,26 +355,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 floating-point types") { - val rows = sql("SELECT * FROM flttypes").collect() + val rows = ctx.sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) assert(rows(0).getDouble(1) === 1.00000011920928955) assert(rows(0).getAs[BigDecimal](2) === new BigDecimal("123456789012345.543215432154321000")) assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) - val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + val result = ctx.sql("SELECT C FROM flttypes where C > C - 1").collect() assert(result(0).getAs[BigDecimal](0) === new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - val rows = sql("SELECT * FROM hack").collect() + val rows = ctx.sql("SELECT * FROM hack").collect() assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) @@ -386,7 +384,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext // We set rowId to false during setup, which means that _ROWID_ column should be absent from // all tables. If rowId is true (default), the query below doesn't throw an exception. 
intercept[JdbcSQLException] { - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE abc |USING org.apache.spark.sql.jdbc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 0284a7a5e858..f2431d892e68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -23,15 +23,12 @@ import java.util.Properties import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 79a25d53b7b6..2b959175d8aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -32,9 +32,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -231,7 +229,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select nullstr, headers.Host from jsonTable"), + ctx.sql("select nullstr, headers.Host from jsonTable"), Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null)) ) } @@ -253,7 +251,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select * from jsonTable"), + ctx.sql("select * from jsonTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -295,44 +293,44 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Access elements of a primitive array. checkAnswer( - sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + ctx.sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( - sql("select arrayOfNull from jsonTable"), + ctx.sql("select arrayOfNull from jsonTable"), Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + ctx.sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. 
checkAnswer( - sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + ctx.sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( - sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + ctx.sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( - sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + ctx.sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), Row("str2", 2.1) ) // Access elements of an array of structs. checkAnswer( - sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + ctx.sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + "from jsonTable"), Row( Row(true, "str1", null), @@ -343,7 +341,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Access a struct and fields inside of it. checkAnswer( - sql("select struct, struct.field1, struct.field2 from jsonTable"), + ctx.sql("select struct, struct.field1, struct.field2 from jsonTable"), Row( Row(true, new java.math.BigDecimal("92233720368547758070")), true, @@ -352,13 +350,13 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Access an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + ctx.sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + ctx.sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), Row(5, null) ) } @@ -368,13 +366,13 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + ctx.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) // Getting all values of a specific field from an array of structs. checkAnswer( - sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + ctx.sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), Row(Seq(true, false, null), Seq("str1", null, null)) ) } @@ -395,7 +393,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select * from jsonTable"), + ctx.sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: Row("false", 21474836470L, @@ -406,49 +404,49 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Number and Boolean conflict: resolve the type as number in this query. 
checkAnswer( - sql("select num_bool - 10 from jsonTable where num_bool > 11"), + ctx.sql("select num_bool - 10 from jsonTable where num_bool > 11"), Row(2) ) // Widening to LongType checkAnswer( - sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), + ctx.sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), Row(21474836370L) :: Row(21474836470L) :: Nil ) checkAnswer( - sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), + ctx.sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) // Widening to DecimalType checkAnswer( - sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + ctx.sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), Row(BigDecimal("21474836472.2")) :: Row(BigDecimal("92233720368547758071.3")) :: Nil ) // Widening to Double checkAnswer( - sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), + ctx.sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 14"), + ctx.sql("select num_str + 1.2 from jsonTable where num_str > 14"), Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + ctx.sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. checkAnswer( - sql("select * from jsonTable where str_bool = 'str1'"), + ctx.sql("select * from jsonTable where str_bool = 'str1'"), Row("true", 11L, null, 1.1, "13.1", "str1") ) } @@ -460,24 +458,24 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Right now, the analyzer does not promote strings in a boolean expression. // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( - sql("select num_bool from jsonTable where NOT num_bool"), + ctx.sql("select num_bool from jsonTable where NOT num_bool"), Row(false) ) checkAnswer( - sql("select str_bool from jsonTable where NOT str_bool"), + ctx.sql("select str_bool from jsonTable where NOT str_bool"), Row(false) ) // Right now, the analyzer does not know that num_bool should be treated as a boolean. // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( - sql("select num_bool from jsonTable where num_bool"), + ctx.sql("select num_bool from jsonTable where num_bool"), Row(true) ) checkAnswer( - sql("select str_bool from jsonTable where str_bool"), + ctx.sql("select str_bool from jsonTable where str_bool"), Row(false) ) @@ -501,7 +499,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // which is not 14.3. // Number and String conflict: resolve the type as number in this query. 
checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 13"), + ctx.sql("select num_str + 1.2 from jsonTable where num_str > 13"), Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } @@ -522,7 +520,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select * from jsonTable"), + ctx.sql("select * from jsonTable"), Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: @@ -544,7 +542,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select * from jsonTable"), + ctx.sql("select * from jsonTable"), Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: @@ -553,7 +551,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Treat an element as a number. checkAnswer( - sql("select array1[0] + 1 from jsonTable where array1 is not null"), + ctx.sql("select array1[0] + 1 from jsonTable where array1 is not null"), Row(2) ) } @@ -623,7 +621,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select * from jsonTable"), + ctx.sql("select * from jsonTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -640,7 +638,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE jsonTableSQL |USING org.apache.spark.sql.json @@ -650,7 +648,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { """.stripMargin) checkAnswer( - sql("select * from jsonTableSQL"), + ctx.sql("select * from jsonTableSQL"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -683,7 +681,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF1.registerTempTable("jsonTable1") checkAnswer( - sql("select * from jsonTable1"), + ctx.sql("select * from jsonTable1"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -700,7 +698,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF2.registerTempTable("jsonTable2") checkAnswer( - sql("select * from jsonTable2"), + ctx.sql("select * from jsonTable2"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -719,7 +717,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") checkAnswer( - sql("select map from jsonWithSimpleMap"), + ctx.sql("select map from jsonWithSimpleMap"), Row(Map("a" -> 1)) :: Row(Map("b" -> 2)) :: Row(Map("c" -> 3)) :: @@ -728,7 +726,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { ) checkAnswer( - sql("select map['c'] from jsonWithSimpleMap"), + ctx.sql("select map['c'] from jsonWithSimpleMap"), Row(null) :: Row(null) :: Row(3) :: @@ -747,7 +745,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonWithComplexMap.registerTempTable("jsonWithComplexMap") checkAnswer( - sql("select map from 
jsonWithComplexMap"), + ctx.sql("select map from jsonWithComplexMap"), Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: Row(Map("b" -> Row(null, 2))) :: Row(Map("c" -> Row(Seq(), 4))) :: @@ -757,7 +755,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { ) checkAnswer( - sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + ctx.sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), Row(Seq(1, 2, 3, null), null) :: Row(null, null) :: Row(null, 4) :: @@ -772,11 +770,11 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + ctx.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) checkAnswer( - sql( + ctx.sql( """ |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] |from jsonTable @@ -790,7 +788,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql( + ctx.sql( """ |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] |from jsonTable @@ -798,7 +796,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { Row(5, 7, 8) ) checkAnswer( - sql( + ctx.sql( """ |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 @@ -813,7 +811,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { jsonDF.registerTempTable("jsonTable") checkAnswer( - sql( + ctx.sql( """ |select a, b, c |from jsonTable @@ -843,7 +841,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // In HiveContext, backticks should be used to access columns starting with a underscore. checkAnswer( - sql( + ctx.sql( """ |SELECT a, b, c, _unparsed |FROM jsonTable @@ -857,7 +855,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { ) checkAnswer( - sql( + ctx.sql( """ |SELECT a, b, c |FROM jsonTable @@ -867,7 +865,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { ) checkAnswer( - sql( + ctx.sql( """ |SELECT _unparsed |FROM jsonTable @@ -902,7 +900,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { assert(schema === jsonDF.schema) checkAnswer( - sql( + ctx.sql( """ |SELECT field1, field2, field3, field4 |FROM jsonTable @@ -965,7 +963,7 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { val primTable = _sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( - sql("select * from primativeTable"), + ctx.sql("select * from primativeTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -979,19 +977,19 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( - sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), + ctx.sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( - sql("select arrayOfNull from complexTable"), + ctx.sql("select arrayOfNull from complexTable"), Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). 
checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + ctx.sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) @@ -999,25 +997,25 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Access elements of an array of arrays. checkAnswer( - sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), + ctx.sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( - sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), + ctx.sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( - sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), + ctx.sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), Row("str2", 2.1) ) // Access a struct and fields inside of it. checkAnswer( - sql("select struct, struct.field1, struct.field2 from complexTable"), + ctx.sql("select struct, struct.field1, struct.field2 from complexTable"), Row( Row(true, new java.math.BigDecimal("92233720368547758070")), true, @@ -1026,13 +1024,13 @@ class JsonSuite extends QueryTest with TestJsonData with SQLTestUtils { // Access an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), + ctx.sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + ctx.sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + "from complexTable"), Row(5, null) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 997779b4985f..7935e3cd9580 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * data type is nullable. 
*/ class ParquetFilterSuite extends QueryTest with ParquetTest { - private val ctx = sqlContext private def checkFilterPredicate( df: DataFrame, @@ -301,7 +300,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("SPARK-6554: don't push down predicates which reference partition columns") { - import ctx.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7035bde00d9f..c07ccbb93352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -63,8 +63,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuite extends QueryTest with ParquetTest { - private val ctx = sqlContext - import ctx.implicits._ + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 9466fd51244a..bf6eaacb4552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -41,10 +41,7 @@ case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: S class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { import PartitioningUtils._ - - private val ctx = sqlContext - import ctx.implicits._ - import ctx._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -284,7 +281,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -292,7 +289,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, i.toString, pi, ps)) checkAnswer( - sql("SELECT intField, pi FROM t"), + ctx.sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -300,14 +297,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, pi)) checkAnswer( - sql("SELECT * FROM t WHERE pi = 1"), + ctx.sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, i.toString, 1, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps = 'foo'"), + ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -331,7 +328,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -339,7 +336,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, pi, i.toString, ps)) checkAnswer( - sql("SELECT intField, pi FROM t"), + ctx.sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -347,14 +344,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, pi)) checkAnswer( - sql("SELECT * FROM t WHERE pi = 1"), + ctx.sql("SELECT * FROM t WHERE pi = 1"), for 
{ i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, 1, i.toString, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps = 'foo'"), + ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -380,7 +377,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -388,14 +385,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, i.toString, pi, ps)) checkAnswer( - sql("SELECT * FROM t WHERE pi IS NULL"), + ctx.sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) } yield Row(i, i.toString, null, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps IS NULL"), + ctx.sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -420,7 +417,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -428,7 +425,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } yield Row(i, pi, i.toString, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps IS NULL"), + ctx.sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -456,7 +453,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 6ecea93ee3fc..0695bf9995dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -29,34 +29,33 @@ import org.apache.spark.util.Utils * A test suite that tests various Parquet queries. 
*/ class ParquetQuerySuite extends QueryTest with ParquetTest { - private val ctx = sqlContext - import ctx._ test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) - checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) + checkAnswer(ctx.sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer( + ctx.sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) } } test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + ctx.sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + ctx.sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -67,7 +66,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val selfJoin = ctx.sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) @@ -82,7 +81,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("nested data - struct with array field") { val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + checkAnswer(ctx.sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) }) } @@ -91,7 +90,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("nested data - array of struct") { val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + checkAnswer(ctx.sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) }) } @@ -99,17 +98,17 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(ctx.sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } test("SPARK-5309 strings stored using dictionary compression in parquet") { withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + checkAnswer(ctx.sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY 
_1, _2"), (0 until 10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + checkAnswer(ctx.sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), List(Row("same", "run_5", 100))) } } @@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = ctx.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length 
+ ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 6b7200bddfee..8fdbc51bbc77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.sources import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils -abstract class DataSourceTest extends QueryTest with BeforeAndAfter with SharedSQLContext { +abstract class DataSourceTest extends QueryTest with BeforeAndAfter with SQLTestUtils { // We want to test some edge cases. protected implicit lazy val caseInsensitiveContext = { val ctx = new SQLContext(sqlContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index ba68d985b2da..469ea48b4f9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -18,14 +18,20 @@ package org.apache.spark.sql.test import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} /** * A collection of sample data used in SQL tests. */ -private[sql] trait SQLTestData { - protected val _sqlContext: SQLContext - import _sqlContext.implicits._ +private[sql] trait SQLTestData { self => + protected def _sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + import internalImplicits._ // All test data should be lazy because the SQLContext is not set up yet @@ -258,8 +264,8 @@ private[sql] trait SQLTestData { } /* ------------------------------ * - | Case classes used in test data | - * ------------------------------ */ + | Case classes used in test data | + * ------------------------------ */ private[sql] case class TestData(key: Int, value: String) private[sql] case class TestData2(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 36bce1f184db..4c4f82149344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -23,29 +23,48 @@ import java.util.UUID import scala.util.Try import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils /** - * General helper trait for common functionality in SQL tests. + * Helper trait that should be extended by all SQL test suites involving a + * [[org.apache.spark.sql.SQLContext]]. 
*/ -private[sql] trait SQLTestUtils - extends SparkFunSuite - with AbstractSQLTestUtils - with SharedSQLContext { - +private[sql] trait SQLTestUtils extends AbstractSQLTestUtils with SharedSQLContext { protected final override def _sqlContext = sqlContext } /** - * Abstract helper trait for SQL tests with a pluggable [[SQLContext]]. + * Helper trait that should be extended by all SQL test suites. + * + * This base trait allows subclasses to plugin a custom [[SQLContext]]. It comes with test + * data prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. */ -private[sql] trait AbstractSQLTestUtils { this: SparkFunSuite => +private[sql] trait AbstractSQLTestUtils extends SparkFunSuite with SQLTestData { self => protected def _sqlContext: SQLContext - protected def configuration = _sqlContext.sparkContext.hadoopConfiguration + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def configuration: Configuration = { + _sqlContext.sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -131,12 +150,11 @@ private[sql] trait AbstractSQLTestUtils { this: SparkFunSuite => try f finally _sqlContext.sql(s"USE default") } - /** * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { DataFrame(_sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index dd2f8aacdfd7..2bed857851c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -32,17 +32,35 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestSQLContext = new TestSQLContext + private var _ctx: TestSQLContext = null /** - * The [[TestSQLContext]] to use for all tests in this suite. + * Initialize the [[TestSQLContext]]. + * This is a no-op if the user explicitly switched to a custom context before this is called. */ - protected def sqlContext: TestSQLContext = _ctx + protected override def beforeAll(): Unit = { + super.beforeAll() + if (_ctx == null) { + _ctx = new TestSQLContext + } + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
+ */ + protected override def afterAll(): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + super.afterAll() + } /** - * Initialize all test data such that all temp tables are properly registered. + * The [[TestSQLContext]] to use for all tests in this suite. */ - protected final def loadTestData(): Unit = _ctx.loadTestData() + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx /** * Switch to a custom [[TestSQLContext]]. @@ -54,8 +72,8 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll protected def switchSQLContext(newContext: () => TestSQLContext): Unit = { if (_ctx != null) { _ctx.sparkContext.stop() - _ctx = newContext() } + _ctx = newContext() } /** @@ -71,12 +89,4 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll } } - protected override def afterAll(): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null - } - super.afterAll() - } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 05cbc24ffc83..091cb8d4d5e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -23,17 +23,14 @@ import org.apache.spark.sql.{SQLConf, SQLContext} /** * A special [[SQLContext]] prepared for testing. */ -private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with SQLTestData { +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => def this() { this(new SparkContext("local[2]", "test-sql-context", new SparkConf().set("spark.sql.testkey", "true"))) } - // For SQLTestData - protected override val _sqlContext: SQLContext = this - - // Use fewer paritions to speed up testing + // Use fewer partitions to speed up testing protected[sql] override def createSession(): SQLSession = new this.SQLSession() /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ @@ -42,4 +39,13 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) with override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } + + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + object testData extends SQLTestData { + protected override def _sqlContext: SQLContext = self + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala index 2357477f1bfd..3b12e96be5e9 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.hive.test -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.AbstractSQLTestUtils /** - * Helper trait analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. + * Helper trait that should be extended by all SQL test suites involving a + * [[org.apache.spark.sql.hive.HiveContext]]. + * + * This is analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. 
*/ -private[spark] trait HiveTestUtils - extends SparkFunSuite - with AbstractSQLTestUtils - with SharedHiveContext { - +private[spark] trait HiveTestUtils extends AbstractSQLTestUtils with SharedHiveContext { protected final override def _sqlContext = hiveContext } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index efec16cf74da..93e50cf33fc1 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -34,11 +34,33 @@ private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfter * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestHiveContext = new TestHiveContext + private var _ctx: TestHiveContext = null + + /** + * Initialize the [[TestHiveContext]]. + * This is a no-op if the user explicitly switched to a custom context before this is called. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestHiveContext + } + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + super.afterAll() + } /** * The [[TestHiveContext]] to use for all tests in this suite. */ + protected def ctx: TestHiveContext = _ctx protected def hiveContext: TestHiveContext = _ctx /** @@ -68,12 +90,4 @@ private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfter } } - protected override def afterAll(): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null - } - super.afterAll() - } - } From 55d0b1bd314dcd61a9808b92cf8099edb315cb9b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 15:30:56 -0700 Subject: [PATCH 10/39] Fix Java not serializable exception in tests Tests that use test data used to fail before this commit. This is because the underlying case classes would bring in the entire `SQLTestData` trait into the scope. This no longer happens after we move the case classes outside of the trait.
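For context, a minimal sketch of the failure mode described above; the names are illustrative and not taken from this patch. A case class declared inside a trait compiles to an inner class that keeps a reference to the enclosing trait instance, so Java-serializing a single row of test data also tries to serialize the whole trait (and whatever non-serializable fields it holds). Declaring the case classes in a top-level object removes that outer reference.

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Hypothetical stand-in for a trait like SQLTestData.
    trait DataHolder {
      val notSerializable = new Object            // stands in for fields such as the SQLContext
      case class Inner(key: Int, value: String)   // inner class: every instance captures DataHolder.this
    }

    // Top-level object: the case class has no outer reference and serializes on its own.
    object StandaloneData {
      case class Record(key: Int, value: String)
    }

    object SerializationSketch extends DataHolder {
      def main(args: Array[String]): Unit = {
        val out = new ObjectOutputStream(new ByteArrayOutputStream())
        out.writeObject(StandaloneData.Record(1, "a"))  // succeeds
        out.writeObject(Inner(1, "a"))                  // fails: the captured outer instance is not serializable
      }
    }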
--- .../org/apache/spark/sql/SQLQuerySuite.scala | 1 + .../scala/org/apache/spark/sql/UDFSuite.scala | 1 + .../columnar/InMemoryColumnarQuerySuite.scala | 1 + .../columnar/PartitionBatchPruningSuite.scala | 1 + .../apache/spark/sql/test/SQLTestData.scala | 48 ++++++++++--------- 5 files changed, 29 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 923bceda0981..5128eae99d62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index da7ff4d49823..77ecdcda2441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ private case class FunctionResult(f1: String, f2: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 1befda35ace3..69d4baf2ca3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 354a7fe98da6..483d3e5b07be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 469ea48b4f9a..85d73d8e78a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -32,8 +32,9 @@ private[sql] trait SQLTestData { self => } import internalImplicits._ + import SQLTestData._ - // All test data should be lazy because the SQLContext is not set up yet + // Note: all test data should be lazy because the 
SQLContext is not set up yet. lazy val testData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( @@ -262,27 +263,28 @@ private[sql] trait SQLTestData { self => salary complexData } +} - /* ------------------------------ * - | Case classes used in test data | - * ------------------------------ */ - - private[sql] case class TestData(key: Int, value: String) - private[sql] case class TestData2(a: Int, b: Int) - private[sql] case class TestData3(a: Int, b: Option[Int]) - private[sql] case class LargeAndSmallInts(a: Int, b: Int) - private[sql] case class DecimalData(a: BigDecimal, b: BigDecimal) - private[sql] case class BinaryData(a: Array[Byte], b: Int) - private[sql] case class UpperCaseData(N: Int, L: String) - private[sql] case class LowerCaseData(n: Int, l: String) - private[sql] case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - private[sql] case class MapData(data: scala.collection.Map[Int, String]) - private[sql] case class StringData(s: String) - private[sql] case class IntField(i: Int) - private[sql] case class NullInts(a: Integer) - private[sql] case class NullStrings(n: Int, s: String) - private[sql] case class TableName(tableName: String) - private[sql] case class Person(id: Int, name: String, age: Int) - private[sql] case class Salary(personId: Int, salary: Double) - private[sql] case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +/** + * Case classes used in test data. + */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) } From 4f59beef61c7aceac8fe6400b720e1ef12bc4beb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 16:33:28 -0700 Subject: [PATCH 11/39] Fix DataSourceTest et al. Test suites that extend DataSourceTest used to have this weird implicit SQLContext that was created in the constructor. This was failing tests because the base SQLContext is not ready until after the first test is run. A minor refactor was required to fix the resulting NPEs. This commit also fixes test suites that need to materialize the test data. These suites were materializing them in the constructor before the SQLContext was ready. 
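To make the ordering issue concrete, here is a minimal sketch with made-up names; it uses plain ScalaTest rather than Spark's test harness. Anything evaluated in a suite's constructor runs before beforeAll(), so an eager val that dereferences the shared context hits a null and throws; a lazy val (or doing the work inside beforeAll()) defers it until after the context exists.

    import org.scalatest.{BeforeAndAfterAll, FunSuite}

    // Stand-in for a shared-context trait: the context only exists after beforeAll().
    trait SketchSharedContext extends FunSuite with BeforeAndAfterAll {
      private var _ctx: String = null
      protected def ctx: String = _ctx
      protected override def beforeAll(): Unit = {
        super.beforeAll()
        _ctx = "context"
      }
    }

    class BrokenSketchSuite extends SketchSharedContext {
      // Evaluated during construction, before beforeAll(): ctx is still null here -> NPE.
      val eager: Int = ctx.length
    }

    class FixedSketchSuite extends SketchSharedContext {
      // Deferred until first use inside a test, i.e. after beforeAll() has run.
      private lazy val deferred: Int = ctx.length

      test("uses the context only after it is initialized") {
        assert(deferred > 0)
      }
    }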
--- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../org/apache/spark/sql/QueryTest.scala | 6 -- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../execution/SparkSqlSerializer2Suite.scala | 3 +- .../sql/execution/TungstenSortSuite.scala | 2 + .../sql/execution/joins/OuterJoinSuite.scala | 11 ++- .../sql/execution/joins/SemiJoinSuite.scala | 10 +-- .../sources/CreateTableAsSelectSuite.scala | 46 +++++----- .../spark/sql/sources/DataSourceTest.scala | 9 +- .../spark/sql/sources/FilteredScanSuite.scala | 6 +- .../spark/sql/sources/InsertSuite.scala | 86 +++++++++---------- .../spark/sql/sources/SaveLoadSuite.scala | 20 ++--- .../spark/sql/sources/TableScanSuite.scala | 42 +++++---- .../apache/spark/sql/test/SQLTestUtils.scala | 25 +++++- .../spark/sql/test/SharedSQLContext.scala | 2 +- 17 files changed, 143 insertions(+), 133 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a3d6b270e775..dcf604bb3e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.test.SQLTestUtils class JoinSuite extends QueryTest with BeforeAndAfterEach with SQLTestUtils { import testImplicits._ - loadTestData() + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..4adcefb7dc4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -71,12 +71,6 @@ class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5128eae99d62..e97bd11351b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -36,7 +36,7 @@ class MyDialect extends DefaultParserDialect class SQLQuerySuite extends QueryTest with SQLTestUtils { import testImplicits._ - loadTestData() + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 69d4baf2ca3d..323d9b2c18ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { import testImplicits._ - loadTestData() + setupTestData() test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 1cec8f676823..03769f0fd7cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{execution, Row, SQLConf} class PlannerSuite extends SparkFunSuite with SQLTestUtils { import testImplicits._ - loadTestData() + setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val _ctx = ctx diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 3535f8f184e9..1a0cc2fce49d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -73,6 +73,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with SQLTestUtils { var useSerializer2: Boolean = _ override def beforeAll(): Unit = { + super.beforeAll() numShufflePartitions = ctx.conf.numShufflePartitions useSerializer2 = ctx.conf.useSqlSerializer2 @@ -111,8 +112,6 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with SQLTestUtils { } ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 9c5a03da4153..7f6651052eba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -30,11 +30,13 @@ import org.apache.spark.sql.types._ class TungstenSortSuite extends SparkPlanTest { override def beforeAll(): Unit = { + super.beforeAll() ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { 
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + super.afterAll() } test("sort followed by limit") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4..00ace3c3b576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -24,22 +24,21 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( + private lazy val left = Seq( (1, 2.0), (2, 1.0), (3, 3.0) ).toDF("a", "b") - val right = Seq( + private lazy val right = Seq( (2, 3.0), (3, 2.0), (4, 1.0) ).toDF("c", "d") - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) + private val leftKeys: List[Expression] = 'a :: Nil + private val rightKeys: List[Expression] = 'c :: Nil + private val condition = Some(LessThan('b, 'd)) test("shuffled hash outer join") { checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 17720c1c5da3..ce97f1297d06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} class SemiJoinSuite extends SparkPlanTest { - val left = Seq( + private lazy val left = Seq( (1, 2.0), (1, 2.0), (2, 1.0), @@ -32,16 +32,16 @@ class SemiJoinSuite extends SparkPlanTest { (3, 3.0) ).toDF("a", "b") - val right = Seq( + private lazy val right = Seq( (2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0) ).toDF("c", "d") - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) + private val leftKeys: List[Expression] = 'a :: Nil + private val rightKeys: List[Expression] = 'c :: Nil + private val condition = Some(LessThan('b, 'd)) test("left semi join hash") { checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85d..282a55864a8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,14 +26,11 @@ import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") @@ -41,6 +38,7 @@ class CreateTableAsSelectSuite 
extends DataSourceTest with BeforeAndAfterAll { override def afterAll(): Unit = { caseInsensitiveContext.dropTempTable("jt") + super.afterAll() } after { @@ -48,7 +46,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { } test("CREATE TEMPORARY TABLE AS SELECT") { - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -59,8 +57,8 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a, b FROM jt").collect()) + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") } @@ -72,7 +70,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { path.setWritable(false) val e = intercept[IOException] { - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -81,7 +79,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { |) AS |SELECT a, b FROM jt """.stripMargin) - sql("SELECT a, b FROM jsonTable").collect() + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable").collect() } assert(e.getMessage().contains("Unable to clear output directory")) @@ -89,7 +87,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { } test("create a table, drop it and create another one with the same name") { - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -100,11 +98,11 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a, b FROM jt").collect()) + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jt").collect()) val message = intercept[DDLException]{ - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -119,7 +117,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") // Overwrite the temporary table. - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -129,14 +127,14 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { |SELECT a * 4 FROM jt """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT a * 4 FROM jt").collect()) + caseInsensitiveContext.sql("SELECT * FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a * 4 FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") // Explicitly delete the data. 
if (path.exists()) Utils.deleteRecursively(path) - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -147,15 +145,15 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT b FROM jt").collect()) + caseInsensitiveContext.sql("SELECT * FROM jsonTable"), + caseInsensitiveContext.sql("SELECT b FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { val message = intercept[DDLException]{ - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -172,7 +170,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { test("a CTAS statement with column definitions is not allowed") { intercept[DDLException]{ - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) |USING org.apache.spark.sql.json.DefaultSource @@ -185,7 +183,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { } test("it is not allowed to write to a table while querying it.") { - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource @@ -196,7 +194,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) val message = intercept[AnalysisException] { - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 8fdbc51bbc77..2da4c06f9a9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -24,11 +24,18 @@ import org.apache.spark.sql.test.SQLTestUtils abstract class DataSourceTest extends QueryTest with BeforeAndAfter with SQLTestUtils { + // We want to test some edge cases. 
- protected implicit lazy val caseInsensitiveContext = { + protected lazy val caseInsensitiveContext: SQLContext = { val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3..e367e5558362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -97,10 +97,8 @@ object FiltersPushed { class FilteredScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql - before { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered |USING org.apache.spark.sql.sources.FilteredScanSource @@ -237,7 +235,7 @@ class FilteredScanSuite extends DataSourceTest { def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { - val queryExecution = sql(sqlString).queryExecution + val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 0b7c46c482c8..adc23339dc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -25,18 +25,15 @@ import org.apache.spark.sql.{SaveMode, AnalysisException, Row} import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - private lazy val sparkContext = caseInsensitiveContext.sparkContext + private var path: File = null - var path: File = null - - override def beforeAll: Unit = { + override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") - sql( + caseInsensitiveContext.sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) |USING org.apache.spark.sql.json.DefaultSource @@ -50,45 +47,46 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { caseInsensitiveContext.dropTempTable("jsonTable") caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) + super.afterAll() } test("Simple INSERT OVERWRITE a JSONRelation") { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) } test("PreInsert casting and renaming") { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 2, s"${i * 4}")) ) - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt 
""".stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 4, s"${i * 6}")) ) } test("SELECT clause generating a different number of columns is not allowed.") { val message = intercept[RuntimeException] { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) @@ -100,45 +98,45 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("INSERT OVERWRITE a JSONRelation multiple times") { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1 """.stripMargin) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 10, s"str$i")) ) @@ -148,7 +146,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("INSERT INTO not supported for JSONRelation for now") { intercept[RuntimeException]{ - sql( + caseInsensitiveContext.sql( s""" |INSERT INTO TABLE jsonTable SELECT a, b FROM jt """.stripMargin) @@ -159,20 +157,20 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") .write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( - sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) } test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable """.stripMargin) @@ -185,50 +183,50 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("Caching") { // Cached Query Execution caseInsensitiveContext.cacheTable("jsonTable") - assertCached(sql("SELECT * FROM jsonTable")) + assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable")) checkAnswer( - sql("SELECT * FROM jsonTable"), + caseInsensitiveContext.sql("SELECT * FROM jsonTable"), (1 to 10).map(i => Row(i, 
s"str$i"))) - assertCached(sql("SELECT a FROM jsonTable")) + assertCached(caseInsensitiveContext.sql("SELECT a FROM jsonTable")) checkAnswer( - sql("SELECT a FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a FROM jsonTable"), (1 to 10).map(Row(_)).toSeq) - assertCached(sql("SELECT a FROM jsonTable WHERE a < 5")) + assertCached(caseInsensitiveContext.sql("SELECT a FROM jsonTable WHERE a < 5")) checkAnswer( - sql("SELECT a FROM jsonTable WHERE a < 5"), + caseInsensitiveContext.sql("SELECT a FROM jsonTable WHERE a < 5"), (1 to 4).map(Row(_)).toSeq) - assertCached(sql("SELECT a * 2 FROM jsonTable")) + assertCached(caseInsensitiveContext.sql("SELECT a * 2 FROM jsonTable")) checkAnswer( - sql("SELECT a * 2 FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + assertCached(caseInsensitiveContext.sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + caseInsensitiveContext.sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt """.stripMargin) // jsonTable should be recached. - assertCached(sql("SELECT * FROM jsonTable")) + assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable")) // The cached data is the new data. checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a * 2, b FROM jt").collect()) + caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + caseInsensitiveContext.sql("SELECT a * 2, b FROM jt").collect()) // Verify uncaching caseInsensitiveContext.uncacheTable("jsonTable") - assertCached(sql("SELECT * FROM jsonTable"), 0) + assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource @@ -239,12 +237,12 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) checkAnswer( - sql("SELECT * FROM oneToTen"), + caseInsensitiveContext.sql("SELECT * FROM oneToTen"), (1 to 10).map(Row(_)).toSeq ) val message = intercept[AnalysisException] { - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index b032515a9d28..c33c81a1a95f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,18 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class SaveLoadSuite extends DataSourceTest { private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = 
null - - var df: DataFrame = null + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -50,6 +43,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { override def afterAll(): Unit = { caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + super.afterAll() } after { @@ -69,7 +63,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), - sql("SELECT b FROM jsonTable").collect()) + caseInsensitiveContext.sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index cfb03ff485b7..31faa72b6232 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -98,8 +98,6 @@ case class AllDataTypesScan( } class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql - private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", @@ -125,7 +123,7 @@ class TableScanSuite extends DataSourceTest { }.toSeq before { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource @@ -137,7 +135,7 @@ class TableScanSuite extends DataSourceTest { |) """.stripMargin) - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE tableWithSchema ( |`string$%Field` stRIng, @@ -232,7 +230,7 @@ class TableScanSuite extends DataSourceTest { assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( - sql( + caseInsensitiveContext.sql( """SELECT | `string$%Field`, | cast(binaryField as string), @@ -285,38 +283,38 @@ class TableScanSuite extends DataSourceTest { test("Caching") { // Cached Query Execution caseInsensitiveContext.cacheTable("oneToTen") - assertCached(sql("SELECT * FROM oneToTen")) + assertCached(caseInsensitiveContext.sql("SELECT * FROM oneToTen")) checkAnswer( - sql("SELECT * FROM oneToTen"), + caseInsensitiveContext.sql("SELECT * FROM oneToTen"), (1 to 10).map(Row(_)).toSeq) - assertCached(sql("SELECT i FROM oneToTen")) + assertCached(caseInsensitiveContext.sql("SELECT i FROM oneToTen")) checkAnswer( - sql("SELECT i FROM oneToTen"), + caseInsensitiveContext.sql("SELECT i FROM oneToTen"), (1 to 10).map(Row(_)).toSeq) - assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + assertCached(caseInsensitiveContext.sql("SELECT i FROM oneToTen WHERE i < 5")) checkAnswer( - sql("SELECT i FROM oneToTen WHERE i < 5"), + caseInsensitiveContext.sql("SELECT i FROM oneToTen WHERE i < 5"), (1 to 4).map(Row(_)).toSeq) - assertCached(sql("SELECT i * 2 FROM oneToTen")) + assertCached(caseInsensitiveContext.sql("SELECT i * 2 FROM oneToTen")) checkAnswer( - sql("SELECT i * 2 FROM oneToTen"), + caseInsensitiveContext.sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + assertCached(caseInsensitiveContext.sql("SELECT a.i, b.i FROM 
oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + caseInsensitiveContext.sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching caseInsensitiveContext.uncacheTable("oneToTen") - assertCached(sql("SELECT * FROM oneToTen"), 0) + assertCached(caseInsensitiveContext.sql("SELECT * FROM oneToTen"), 0) } test("defaultSource") { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenDef |USING org.apache.spark.sql.sources @@ -327,7 +325,7 @@ class TableScanSuite extends DataSourceTest { """.stripMargin) checkAnswer( - sql("SELECT * FROM oneToTenDef"), + caseInsensitiveContext.sql("SELECT * FROM oneToTenDef"), (1 to 10).map(Row(_)).toSeq) } @@ -335,7 +333,7 @@ class TableScanSuite extends DataSourceTest { // Make sure we do throw correct exception when users use a relation provider that // only implements the RelationProvier or the SchemaRelationProvider. val schemaNotAllowed = intercept[Exception] { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE relationProvierWithSchema (i int) |USING org.apache.spark.sql.sources.SimpleScanSource @@ -348,7 +346,7 @@ class TableScanSuite extends DataSourceTest { assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) val schemaNeeded = intercept[Exception] { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE schemaRelationProvierWithoutSchema |USING org.apache.spark.sql.sources.AllDataTypesScanSource @@ -362,7 +360,7 @@ class TableScanSuite extends DataSourceTest { } test("SPARK-5196 schema field with comment") { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int) |USING org.apache.spark.sql.sources.AllDataTypesScanSource @@ -374,7 +372,7 @@ class TableScanSuite extends DataSourceTest { |) """.stripMargin) - val planned = sql("SELECT * FROM student").queryExecution.executedPlan + val planned = caseInsensitiveContext.sql("SELECT * FROM student").queryExecution.executedPlan val comments = planned.schema.fields.map { field => if (field.metadata.contains("comment")) field.metadata.getString("comment") else "NO_COMMENT" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c4f82149344..dbe75654e56f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} @@ -45,9 +46,16 @@ private[sql] trait SQLTestUtils extends AbstractSQLTestUtils with SharedSQLConte * data prepared in advance as well as all implicit conversions used extensively by dataframes. * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. 
*/ -private[sql] trait AbstractSQLTestUtils extends SparkFunSuite with SQLTestData { self => +private[sql] trait AbstractSQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + protected def _sqlContext: SQLContext + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + /** * A helper object for importing SQL implicits. * @@ -59,6 +67,21 @@ private[sql] trait AbstractSQLTestUtils extends SparkFunSuite with SQLTestData { protected override def _sqlContext: SQLContext = self._sqlContext } + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + /** * The Hadoop configuration used by the active [[SQLContext]]. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 2bed857851c7..e3eb0114fb1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -39,10 +39,10 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll * This is a no-op if the user explicitly switched to a custom context before this is called. */ protected override def beforeAll(): Unit = { - super.beforeAll() if (_ctx == null) { _ctx = new TestSQLContext } + super.beforeAll() } /** From 88d4f16f543e65dc709b99034fd169687fd0b2b1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 19:35:22 -0700 Subject: [PATCH 12/39] Fix hive tests to use the same pattern This makes hive tests use the same pattern as SQL tests, i.e. everything inherits HiveTestUtils, and those that want to use implicits can do `import testImplicits._`. 
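
A minimal sketch of the suite layout this commit converges on (not part of the diffs that follow; the suite and table names are illustrative, and HiveTestUtils is assumed to supply the shared `ctx` and `testImplicits` exactly as the changed suites below indicate):

package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.HiveTestUtils

// Illustrative example only: mix in HiveTestUtils, take DataFrame conversions
// from `testImplicits._` rather than `ctx.implicits._`, and route every SQL
// statement through `ctx.sql` instead of importing a bare `sql`.
class ExampleHiveSuite extends QueryTest with HiveTestUtils {
  import testImplicits._

  override def beforeAll(): Unit = {
    super.beforeAll()  // let the shared Hive context come up before using it
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("example")
  }

  override def afterAll(): Unit = {
    ctx.dropTempTable("example")
    super.afterAll()
  }

  test("SQL goes through ctx.sql, not a bare sql import") {
    checkAnswer(
      ctx.sql("SELECT key FROM example WHERE key = 1"),
      Row(1) :: Nil)
  }
}

Suites that need per-suite setup (for example HiveDataFrameAnalyticsSuite below) follow the same shape: override beforeAll/afterAll and call super so the shared context is initialized before any tables are registered and torn down after they are dropped.
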
--- .../HiveWindowFunctionQuerySuite.scala | 27 +- .../spark/sql/hive/test/TestHiveContext.scala | 7 +- .../sql/hive/test/HiveSparkPlanTest.scala | 6 +- .../sql/hive/test/SharedHiveContext.scala | 2 +- .../spark/sql/hive/CachedTableSuite.scala | 142 +++--- .../spark/sql/hive/ErrorPositionSuite.scala | 10 +- .../hive/HiveDataFrameAnalyticsSuite.scala | 36 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 7 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 26 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 7 +- .../spark/sql/hive/HiveParquetSuite.scala | 30 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 11 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 121 +++-- .../spark/sql/hive/ListTablesSuite.scala | 34 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 232 +++++---- .../spark/sql/hive/MultiDatabaseSuite.scala | 17 +- .../hive/ParquetHiveCompatibilitySuite.scala | 2 - .../spark/sql/hive/QueryPartitionSuite.scala | 28 +- .../spark/sql/hive/SerializationSuite.scala | 5 +- .../spark/sql/hive/StatisticsSuite.scala | 55 ++- .../org/apache/spark/sql/hive/UDFSuite.scala | 5 +- .../execution/AggregationQuerySuite.scala | 11 +- .../execution/BigDataBenchmarkSuite.scala | 40 +- .../hive/execution/HiveComparisonTest.scala | 25 +- .../sql/hive/execution/HiveExplainSuite.scala | 18 +- .../HiveOperatorQueryableSuite.scala | 16 +- .../sql/hive/execution/HivePlanTest.scala | 12 +- .../sql/hive/execution/HiveQuerySuite.scala | 290 +++++------ .../hive/execution/HiveResolutionSuite.scala | 23 +- .../sql/hive/execution/HiveSerDeSuite.scala | 10 +- .../hive/execution/HiveTableScanSuite.scala | 32 +- .../sql/hive/execution/HiveUDFSuite.scala | 154 +++--- .../sql/hive/execution/PruningSuite.scala | 22 +- .../sql/hive/execution/SQLQuerySuite.scala | 463 +++++++++--------- .../execution/ScriptTransformationSuite.scala | 1 - .../hive/orc/OrcHadoopFsRelationSuite.scala | 8 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 49 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 56 +-- .../spark/sql/hive/orc/OrcSourceSuite.scala | 43 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 6 +- .../apache/spark/sql/hive/parquetSuites.scala | 234 +++++---- .../CommitFailureTestRelationSuite.scala | 2 - .../ParquetHadoopFsRelationSuite.scala | 14 +- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- .../sql/sources/hadoopFsRelationSuites.scala | 6 +- 45 files changed, 1142 insertions(+), 1209 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 234ec481e79c..7bd13f437227 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils * files, every `createQueryTest` calls should explicitly set `reset` to `false`. 
*/ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { - import ctx._ - + private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() @@ -45,8 +44,8 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte Locale.setDefault(Locale.US) // Create the table used in windowing.q - sql("DROP TABLE IF EXISTS part") - sql( + ctx.sql("DROP TABLE IF EXISTS part") + ctx.sql( """ |CREATE TABLE part( | p_partkey INT, @@ -60,13 +59,13 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte | p_comment STRING) """.stripMargin) val testData1 = ctx.getHiveFile("data/files/part_tiny.txt").getCanonicalPath - sql( + ctx.sql( s""" |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part """.stripMargin) - sql("DROP TABLE IF EXISTS over1k") - sql( + ctx.sql("DROP TABLE IF EXISTS over1k") + ctx.sql( """ |create table over1k( | t tinyint, @@ -84,7 +83,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |fields terminated by '|' """.stripMargin) val testData2 = ctx.getHiveFile("data/files/over1k").getCanonicalPath - sql( + ctx.sql( s""" |LOAD DATA LOCAL INPATH '$testData2' overwrite into table over1k """.stripMargin) @@ -92,11 +91,11 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. // This is used to generate golden files. - sql("set hive.plan.serialization.format=kryo") + ctx.sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - sql(s"set fs.default.name=file://$testTempDir/") + ctx.sql(s"set fs.default.name=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. - sql("set mapred.job.tracker=local") + ctx.sql("set mapred.job.tracker=local") } override def afterAll() { @@ -775,11 +774,11 @@ class HiveWindowFunctionQueryFileSuite // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. // This is used to generate golden files. - // sql("set hive.plan.serialization.format=kryo") + // ctx.sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - // sql(s"set fs.default.name=file://$testTempDir/") + // ctx.sql(s"set fs.default.name=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. 
- // sql("set mapred.job.tracker=local") + // ctx.sql("set mapred.job.tracker=local") } override def afterAll() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 6fc4518fa78f..ad7eb19e17d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -53,6 +53,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => import HiveContext._ + import TestHiveContext._ def this() { this(new SparkContext( @@ -195,8 +196,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - case class TestTable(name: String, commands: (() => Unit)*) - protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { () => new QueryExecution(sql).stringResult(): Unit @@ -454,3 +453,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } } + +private[hive] object TestHiveContext { + case class TestTable(name: String, commands: (() => Unit)*) +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala index 2582461f2bae..c31ee732670c 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.test import org.apache.spark.sql.execution.AbstractSparkPlanTest -import org.apache.spark.sql.SQLContext /** * Base class for writing tests for individual physical operators in hive. @@ -26,7 +25,4 @@ import org.apache.spark.sql.SQLContext */ private[sql] abstract class HiveSparkPlanTest extends AbstractSparkPlanTest - with SharedHiveContext { - - protected override def _sqlContext: SQLContext = hiveContext -} + with HiveTestUtils diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index 93e50cf33fc1..fb1b57b83d76 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -41,7 +41,7 @@ private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfter * This is a no-op if the user explicitly switched to a custom context before this is called. 
*/ protected override def beforeAll(): Unit = { - if (_ctx != null) { + if (_ctx == null) { _ctx = new TestHiveContext } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 78f3db14ef10..1be12bf9a884 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -21,16 +21,14 @@ import java.io.File import org.apache.spark.sql.{SaveMode, AnalysisException, QueryTest} import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx._ +class CachedTableSuite extends QueryTest with HiveTestUtils { def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -40,168 +38,168 @@ class CachedTableSuite extends QueryTest with SharedHiveContext { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache table") { - val preCacheResults = sql("SELECT * FROM src").collect().toSeq + val preCacheResults = ctx.sql("SELECT * FROM src").collect().toSeq - cacheTable("src") - assertCached(sql("SELECT * FROM src")) + ctx.cacheTable("src") + assertCached(ctx.sql("SELECT * FROM src")) checkAnswer( - sql("SELECT * FROM src"), + ctx.sql("SELECT * FROM src"), preCacheResults) - assertCached(sql("SELECT * FROM src s")) + assertCached(ctx.sql("SELECT * FROM src s")) checkAnswer( - sql("SELECT * FROM src s"), + ctx.sql("SELECT * FROM src s"), preCacheResults) - uncacheTable("src") - assertCached(sql("SELECT * FROM src"), 0) + ctx.uncacheTable("src") + assertCached(ctx.sql("SELECT * FROM src"), 0) } test("cache invalidation") { - sql("CREATE TABLE cachedTable(key INT, value STRING)") + ctx.sql("CREATE TABLE cachedTable(key INT, value STRING)") - sql("INSERT INTO TABLE cachedTable SELECT * FROM src") - checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + ctx.sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer(ctx.sql("SELECT * FROM cachedTable"), ctx.table("src").collect().toSeq) - cacheTable("cachedTable") - checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + ctx.cacheTable("cachedTable") + checkAnswer(ctx.sql("SELECT * FROM cachedTable"), ctx.table("src").collect().toSeq) - sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + ctx.sql("INSERT INTO TABLE cachedTable SELECT * FROM src") checkAnswer( - sql("SELECT * FROM cachedTable"), - table("src").collect().toSeq ++ table("src").collect().toSeq) + ctx.sql("SELECT * FROM cachedTable"), + ctx.table("src").collect().toSeq ++ ctx.table("src").collect().toSeq) - sql("DROP TABLE cachedTable") + ctx.sql("DROP TABLE cachedTable") } test("Drop cached table") { - sql("CREATE TABLE cachedTableTest(a INT)") - cacheTable("cachedTableTest") - sql("SELECT * FROM cachedTableTest").collect() - sql("DROP TABLE cachedTableTest") + ctx.sql("CREATE 
TABLE cachedTableTest(a INT)") + ctx.cacheTable("cachedTableTest") + ctx.sql("SELECT * FROM cachedTableTest").collect() + ctx.sql("DROP TABLE cachedTableTest") intercept[AnalysisException] { - sql("SELECT * FROM cachedTableTest").collect() + ctx.sql("SELECT * FROM cachedTableTest").collect() } } test("DROP nonexistant table") { - sql("DROP TABLE IF EXISTS nonexistantTable") + ctx.sql("DROP TABLE IF EXISTS nonexistantTable") } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("src") + ctx.uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - sql("CACHE TABLE src") - assertCached(table("src")) - assert(isCached("src"), "Table 'src' should be cached") + ctx.sql("CACHE TABLE src") + assertCached(ctx.table("src")) + assert(ctx.isCached("src"), "Table 'src' should be cached") - sql("UNCACHE TABLE src") - assertCached(table("src"), 0) - assert(!isCached("src"), "Table 'src' should not be cached") + ctx.sql("UNCACHE TABLE src") + assertCached(ctx.table("src"), 0) + assert(!ctx.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assertCached(table("testCacheTable")) + ctx.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE TABLE tableName AS SELECT ...") { - sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") - assertCached(table("testCacheTable")) + ctx.sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE LAZY TABLE tableName") { - sql("CACHE LAZY TABLE src") - assertCached(table("src")) + ctx.sql("CACHE LAZY TABLE src") + assertCached(ctx.table("src")) val rddId = rddIdOf("src") assert( !isMaterialized(rddId), "Lazily cached in-memory table shouldn't be materialized eagerly") - sql("SELECT COUNT(*) FROM src").collect() + ctx.sql("SELECT COUNT(*) FROM src").collect() assert( isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("src") + ctx.uncacheTable("src") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE TABLE with Hive UDF") { - sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") - assertCached(table("udfTest")) - uncacheTable("udfTest") + ctx.sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") + assertCached(ctx.table("udfTest")) + ctx.uncacheTable("udfTest") } test("REFRESH TABLE also needs to recache the data (data source tables)") { val tempPath: File = Utils.createTempDir() tempPath.delete() - table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) - sql("DROP TABLE IF EXISTS refreshTable") - createExternalTable("refreshTable", tempPath.toString, "parquet") + 
ctx.table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) + ctx.sql("DROP TABLE IF EXISTS refreshTable") + ctx.createExternalTable("refreshTable", tempPath.toString, "parquet") checkAnswer( - table("refreshTable"), - table("src").collect()) + ctx.table("refreshTable"), + ctx.table("src").collect()) // Cache the table. - sql("CACHE TABLE refreshTable") - assertCached(table("refreshTable")) + ctx.sql("CACHE TABLE refreshTable") + assertCached(ctx.table("refreshTable")) // Append new data. - table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) + ctx.table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) // We are still using the old data. - assertCached(table("refreshTable")) + assertCached(ctx.table("refreshTable")) checkAnswer( - table("refreshTable"), - table("src").collect()) + ctx.table("refreshTable"), + ctx.table("src").collect()) // Refresh the table. - sql("REFRESH TABLE refreshTable") + ctx.sql("REFRESH TABLE refreshTable") // We are using the new data. - assertCached(table("refreshTable")) + assertCached(ctx.table("refreshTable")) checkAnswer( - table("refreshTable"), - table("src").unionAll(table("src")).collect()) + ctx.table("refreshTable"), + ctx.table("src").unionAll(ctx.table("src")).collect()) // Drop the table and create it again. - sql("DROP TABLE refreshTable") - createExternalTable("refreshTable", tempPath.toString, "parquet") + ctx.sql("DROP TABLE refreshTable") + ctx.createExternalTable("refreshTable", tempPath.toString, "parquet") // It is not cached. - assert(!isCached("refreshTable"), "refreshTable should not be cached.") + assert(!ctx.isCached("refreshTable"), "refreshTable should not be cached.") // Refresh the table. REFRESH TABLE command should not make a uncached // table cached. - sql("REFRESH TABLE refreshTable") + ctx.sql("REFRESH TABLE refreshTable") checkAnswer( - table("refreshTable"), - table("src").unionAll(table("src")).collect()) + ctx.table("refreshTable"), + ctx.table("src").unionAll(ctx.table("src")).collect()) // It is not cached. 
- assert(!isCached("refreshTable"), "refreshTable should not be cached.") + assert(!ctx.isCached("refreshTable"), "refreshTable should not be cached.") - sql("DROP TABLE refreshTable") + ctx.sql("DROP TABLE refreshTable") Utils.deleteRecursively(tempPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index c0aa6c281d0c..77eacc66e5e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,14 +22,12 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class ErrorPositionSuite extends QueryTest with BeforeAndAfter with HiveTestUtils { + import testImplicits._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -124,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveCo test(name) { val error = intercept[AnalysisException] { - quietly(sql(query)) + quietly(ctx.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2f0d54aba112..0d94bb3cbf78 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,48 +19,48 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class HiveDataFrameAnalyticsSuite extends QueryTest with HiveTestUtils { + import testImplicits._ - private var testData: DataFrame = _ + private var _testData: DataFrame = _ - override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") - registerDataFrameAsTable(testData, "mytable") + override def beforeAll(): Unit = { + super.beforeAll() + _testData = Seq((1, 2), (2, 4)).toDF("a", "b") + ctx.registerDataFrameAsTable(_testData, "mytable") } override def afterAll(): Unit = { - dropTempTable("mytable") + ctx.dropTempTable("mytable") + super.afterAll() } test("rollup") { checkAnswer( - testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + _testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), + ctx.sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() ) checkAnswer( - testData.rollup("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + _testData.rollup("a", "b").agg(sum("b")), + ctx.sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() ) } test("cube") { checkAnswer( - testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + _testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + ctx.sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() ) checkAnswer( - testData.cube("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + _testData.cube("a", "b").agg(sum("b")), + ctx.sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 22b6e1d25804..dbed3bce6d55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class HiveDataFrameJoinSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ +class HiveDataFrameJoinSuite extends QueryTest with HiveTestUtils { + import testImplicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. 
test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 14f0c0252013..c15bf1cd95d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class HiveDataFrameWindowSuite extends QueryTest with HiveTestUtils { + import testImplicits._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -56,7 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( + ctx.sql( """SELECT | lead(value) OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -69,7 +67,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( + ctx.sql( """SELECT | lag(value) OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -82,7 +80,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( + ctx.sql( """SELECT | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -95,7 +93,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( + ctx.sql( """SELECT | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -118,7 +116,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { rank().over(Window.partitionBy("value").orderBy("key")), cumeDist().over(Window.partitionBy("value").orderBy("key")), percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( + ctx.sql( s"""SELECT |key, |max(key) over (partition by value order by key), @@ -141,7 +139,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( + ctx.sql( """SELECT | avg(key) OVER | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) @@ -154,7 +152,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( + ctx.sql( """SELECT | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) @@ -172,7 +170,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { last("value").over( 
Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( + ctx.sql( """SELECT | key, | last_value(value) OVER @@ -201,7 +199,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - sql( + ctx.sql( """SELECT | key, | last_value(value) OVER diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 02c5a4b4c92e..18e52b739757 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext with Logging { - private val ctx = hiveContext +class HiveMetastoreCatalogSuite extends SparkFunSuite with HiveTestUtils with Logging { + import testImplicits._ test("struct field should accept underscore in sub-column name") { val metastr = "struct" @@ -39,7 +39,6 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext wit } test("duplicated metastore relations") { - import ctx.implicits._ val df = ctx.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 87ce5e815e91..73486836c4f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -23,40 +23,38 @@ import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with HiveParquetTest { - private val ctx = hiveContext - import ctx._ test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) + checkAnswer(ctx.sql("SELECT upper FROM cases"), expected) + checkAnswer(ctx.sql("SELECT LOWER FROM cases"), expected) } } test("SELECT on Parquet table") { val data = (1 to 4).map(i => (i, s"val_$i")) withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) + checkAnswer(ctx.sql("SELECT * FROM t"), data.map(Row.fromTuple)) } } test("Simple column projection + filter on Parquet table") { withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + ctx.sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), Seq(Row(true, "val_2"), Row(true, "val_4"))) } } test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + 
ctx.sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) + ctx.sql("SELECT * FROM src ORDER BY key"), + ctx.sql("SELECT * from p ORDER BY key").collect().toSeq) } } } @@ -64,14 +62,14 @@ class HiveParquetSuite extends QueryTest with HiveParquetTest { test("INSERT OVERWRITE TABLE Parquet table") { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + ctx.sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(ctx.sql("SELECT * FROM p"), ctx.sql("SELECT * FROM t").collect().toSeq) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 0aae289dc94c..4dadb3902be6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHiveContext, SharedHiveContext} +import org.apache.spark.sql.hive.test.{TestHiveContext, HiveTestUtils} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -39,10 +39,7 @@ class HiveSparkSubmitSuite with Matchers with ResetSystemProperties with Timeouts - with SharedHiveContext { - - private val ctx = hiveContext - import ctx._ + with HiveTestUtils { // TODO: rewrite these or mark them as slow tests to be run sparingly @@ -50,8 +47,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = ctx.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = ctx.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d02f2ac8beb5..5e440411f89c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -24,39 +24,36 @@ import 
org.scalatest.BeforeAndAfter import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -case class TestData(key: Int, value: String) - case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with HiveTestUtils { + import testImplicits._ - private val _testData = sparkContext.parallelize( + private val _testData = ctx.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - reset() + ctx.reset() // Register the testData, which will be used in every test. _testData.registerTempTable("testData") } test("insertInto() HiveTable") { - sql("CREATE TABLE createAndInsertTest (key int, value string)") + ctx.sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. _testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( - sql("SELECT * FROM createAndInsertTest"), + ctx.sql("SELECT * FROM createAndInsertTest"), _testData.collect().toSeq ) @@ -65,7 +62,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with Shared // Make sure the table has been updated. checkAnswer( - sql("SELECT * FROM createAndInsertTest"), + ctx.sql("SELECT * FROM createAndInsertTest"), _testData.toDF().collect().toSeq ++ _testData.toDF().collect().toSeq ) @@ -74,71 +71,71 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with Shared // Make sure the registered table has also been updated. 
checkAnswer( - sql("SELECT * FROM createAndInsertTest"), + ctx.sql("SELECT * FROM createAndInsertTest"), _testData.collect().toSeq ) } test("Double create fails when allowExisting = false") { - sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") val message = intercept[QueryExecutionException] { - sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") }.getMessage } test("Double create does not fail when allowExisting = true") { - sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") + ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + ctx.sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = sparkContext.parallelize( + val rowRDD = ctx.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") - sql("CREATE TABLE hiveTableWithMapValue(m MAP )") - sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") + ctx.sql("CREATE TABLE hiveTableWithMapValue(m MAP )") + ctx.sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") checkAnswer( - sql("SELECT * FROM hiveTableWithMapValue"), + ctx.sql("SELECT * FROM hiveTableWithMapValue"), rowRDD.collect().toSeq ) - sql("DROP TABLE hiveTableWithMapValue") + ctx.sql("DROP TABLE hiveTableWithMapValue") } test("SPARK-4203:random partition directory order") { - sql("CREATE TABLE tmp_table (key int, value string)") + ctx.sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - sql( + ctx.sql( s""" |CREATE TABLE table_with_partition(c1 string) |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) |location '${tmpDir.toURI.toString}' """.stripMargin) - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='1') |SELECT 'blarr' FROM tmp_table """.stripMargin) - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='2') |SELECT 'blarr' FROM tmp_table """.stripMargin) - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='3') |SELECT 'blarr' FROM tmp_table """.stripMargin) - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='4') @@ -160,104 +157,104 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with Shared "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE tmp_table") + ctx.sql("DROP TABLE table_with_partition") + ctx.sql("DROP TABLE tmp_table") } test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = 
sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = createDataFrame(rowRDD, schema) + val rowRDD = ctx.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = ctx.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") - sql("CREATE TABLE hiveTableWithArrayValue(a Array )") - sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") + ctx.sql("CREATE TABLE hiveTableWithArrayValue(a Array )") + ctx.sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") checkAnswer( - sql("SELECT * FROM hiveTableWithArrayValue"), + ctx.sql("SELECT * FROM hiveTableWithArrayValue"), rowRDD.collect().toSeq) - sql("DROP TABLE hiveTableWithArrayValue") + ctx.sql("DROP TABLE hiveTableWithArrayValue") } test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = sparkContext.parallelize( + val rowRDD = ctx.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") - sql("CREATE TABLE hiveTableWithMapValue(m Map )") - sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") + ctx.sql("CREATE TABLE hiveTableWithMapValue(m Map )") + ctx.sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") checkAnswer( - sql("SELECT * FROM hiveTableWithMapValue"), + ctx.sql("SELECT * FROM hiveTableWithMapValue"), rowRDD.collect().toSeq) - sql("DROP TABLE hiveTableWithMapValue") + ctx.sql("DROP TABLE hiveTableWithMapValue") } test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = sparkContext.parallelize( + val rowRDD = ctx.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") - sql("CREATE TABLE hiveTableWithStructValue(s Struct )") - sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") + ctx.sql("CREATE TABLE hiveTableWithStructValue(s Struct )") + ctx.sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") checkAnswer( - sql("SELECT * FROM hiveTableWithStructValue"), + ctx.sql("SELECT * FROM hiveTableWithStructValue"), rowRDD.collect().toSeq) - sql("DROP TABLE hiveTableWithStructValue") + ctx.sql("DROP TABLE hiveTableWithStructValue") } test("SPARK-5498:partition schema does not match table schema") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = sparkContext.parallelize( + val testDatawithNull = ctx.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() - sql( + ctx.sql( s""" |CREATE TABLE table_with_partition(key int,value string) |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' """.stripMargin) - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (ds='1') SELECT key,value FROM testData """.stripMargin) // test schema the same between partition and table - 
sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") - checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(ctx.sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) // test difference type of field - sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") - checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(ctx.sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) // add column to table - sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") - checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + ctx.sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(ctx.sql("select key,value,key1 from table_with_partition where ds='1' "), testDatawithNull.collect().toSeq ) // change column name to table - sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") - checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(ctx.sql("select keynew,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - sql("DROP TABLE table_with_partition") + ctx.sql("DROP TABLE table_with_partition") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index b6aa12a5f6e3..4aa97be78ac8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -18,35 +18,33 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class ListTablesSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class ListTablesSuite extends QueryTest with HiveTestUtils { + import testImplicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + ctx.sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
- catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) - sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") - sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") - sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + ctx.catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + ctx.catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + ctx.sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + ctx.sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + ctx.sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) - sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") - sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") - sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + ctx.sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + ctx.sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + ctx.sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") } test("get all tables of current database") { - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case allTables => // We are using default DB. checkAnswer( @@ -61,7 +59,7 @@ class ListTablesSuite extends QueryTest with SharedHiveContext { } test("getting all tables with a database name") { - Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach { + Seq(ctx.tables("listtablessuiteDb"), ctx.sql("SHOW TABLes in listTablesSuitedb")).foreach { case allTables => checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 5a1aa079c485..a722f31dab07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -41,9 +41,7 @@ class MetastoreDataSourcesSuite with HiveTestUtils with Logging { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + import testImplicits._ var jsonFilePath: String = _ @@ -53,7 +51,7 @@ class MetastoreDataSourcesSuite test("persistent JSON table") { withTable("jsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -62,14 +60,14 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(jsonFilePath).collect().toSeq) + ctx.sql("SELECT * FROM jsonTable"), + ctx.read.json(jsonFilePath).collect().toSeq) } } test("persistent JSON table with a user specified schema") { withTable("jsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable ( |a string, |b String, @@ -82,10 +80,10 @@ class MetastoreDataSourcesSuite """.stripMargin) withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") 
checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + ctx.sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + ctx.sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) } } } @@ -94,7 +92,7 @@ class MetastoreDataSourcesSuite withTable("jsonTable") { // This works because JSON objects are self-describing and JSONRelation can get needed // field values based on field names. - sql( + ctx.sql( s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -109,20 +107,20 @@ class MetastoreDataSourcesSuite StructField("", innerStruct, true), StructField("b", StringType, true))) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === ctx.table("jsonTable").schema) withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable")) + ctx.sql("SELECT b, ``.`=` FROM jsonTable"), + ctx.sql("SELECT b, ``.`=` FROM expectedJsonTable")) } } } test("resolve shortened provider names") { withTable("jsonTable") { - sql( + ctx.sql( s""" |CREATE TABLE jsonTable |USING org.apache.spark.sql.json @@ -132,14 +130,14 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(jsonFilePath).collect().toSeq) + ctx.sql("SELECT * FROM jsonTable"), + ctx.read.json(jsonFilePath).collect().toSeq) } } test("drop table") { withTable("jsonTable") { - sql( + ctx.sql( s""" |CREATE TABLE jsonTable |USING org.apache.spark.sql.json @@ -149,13 +147,13 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(jsonFilePath)) + ctx.sql("SELECT * FROM jsonTable"), + ctx.read.json(jsonFilePath)) - sql("DROP TABLE jsonTable") + ctx.sql("DROP TABLE jsonTable") intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() + ctx.sql("SELECT * FROM jsonTable").collect() } assert( @@ -170,7 +168,7 @@ class MetastoreDataSourcesSuite withTable("jsonTable") { (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -179,7 +177,7 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT * FROM jsonTable"), Row("a", "b")) Utils.deleteRecursively(tempDir) @@ -188,14 +186,14 @@ class MetastoreDataSourcesSuite // Schema is cached so the new column does not show. The updated values in existing columns // will show. 
checkAnswer( - sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT * FROM jsonTable"), Row("a1", "b1")) - sql("REFRESH TABLE jsonTable") + ctx.sql("REFRESH TABLE jsonTable") // Check that the refresh worked checkAnswer( - sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) } } @@ -206,7 +204,7 @@ class MetastoreDataSourcesSuite (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) withTable("jsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -215,15 +213,15 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT * FROM jsonTable"), Row("a", "b")) Utils.deleteRecursively(tempDir) (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) - sql("DROP TABLE jsonTable") + ctx.sql("DROP TABLE jsonTable") - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -233,7 +231,7 @@ class MetastoreDataSourcesSuite // New table should reflect new schema. checkAnswer( - sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) } } @@ -241,7 +239,7 @@ class MetastoreDataSourcesSuite test("invalidate cache and reload") { withTable("jsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable (`c_!@(3)` int) |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -250,23 +248,23 @@ class MetastoreDataSourcesSuite """.stripMargin) withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + ctx.sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) // Discard the cached relation. 
- invalidateTable("jsonTable") + ctx.invalidateTable("jsonTable") checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + ctx.sql("SELECT * FROM jsonTable"), + ctx.sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") + ctx.invalidateTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === ctx.table("jsonTable").schema) } } } @@ -274,7 +272,7 @@ class MetastoreDataSourcesSuite test("CTAS") { withTempPath { tempPath => withTable("jsonTable", "ctasJsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -282,7 +280,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - sql( + ctx.sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -291,11 +289,11 @@ class MetastoreDataSourcesSuite |SELECT * FROM jsonTable """.stripMargin) - assert(table("ctasJsonTable").schema === table("jsonTable").schema) + assert(ctx.table("ctasJsonTable").schema === ctx.table("jsonTable").schema) checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + ctx.sql("SELECT * FROM ctasJsonTable"), + ctx.sql("SELECT * FROM jsonTable").collect()) } } } @@ -305,7 +303,7 @@ class MetastoreDataSourcesSuite val tempPath = path.getCanonicalPath withTable("jsonTable", "ctasJsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -313,7 +311,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - sql( + ctx.sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -324,7 +322,7 @@ class MetastoreDataSourcesSuite // Create the table again should trigger a AnalysisException. val message = intercept[AnalysisException] { - sql( + ctx.sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -341,7 +339,7 @@ class MetastoreDataSourcesSuite // The following statement should be fine if it has IF NOT EXISTS. // It tries to create a table ctasJsonTable with a new schema. // The actual table's schema and data should not be changed. - sql( + ctx.sql( s"""CREATE TABLE IF NOT EXISTS ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -351,21 +349,21 @@ class MetastoreDataSourcesSuite """.stripMargin) // Discard the cached relation. - invalidateTable("ctasJsonTable") + ctx.invalidateTable("ctasJsonTable") // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) + assert(ctx.table("ctasJsonTable").schema === ctx.table("jsonTable").schema) // Table data should not be changed. 
checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + ctx.sql("SELECT * FROM ctasJsonTable"), + ctx.sql("SELECT * FROM jsonTable").collect()) } } } test("CTAS a managed table") { withTable("jsonTable", "ctasJsonTable", "loadedTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -373,13 +371,13 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val expectedPath = ctx.catalog.hiveDefaultTableFilePath("ctasJsonTable") val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = filesystemPath.getFileSystem(ctx.sparkContext.hadoopConfiguration) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) // It is a managed table when we do not specify the location. - sql( + ctx.sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |AS @@ -388,7 +386,7 @@ class MetastoreDataSourcesSuite assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - sql( + ctx.sql( s"""CREATE TABLE loadedTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -396,20 +394,20 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - assert(table("ctasJsonTable").schema === table("loadedTable").schema) + assert(ctx.table("ctasJsonTable").schema === ctx.table("loadedTable").schema) checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable")) + ctx.sql("SELECT * FROM ctasJsonTable"), + ctx.sql("SELECT * FROM loadedTable")) - sql("DROP TABLE ctasJsonTable") + ctx.sql("DROP TABLE ctasJsonTable") assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") } } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { withTable("jsonTable") { - sql( + ctx.sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -417,7 +415,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) + ctx.sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) } } @@ -432,21 +430,21 @@ class MetastoreDataSourcesSuite .saveAsTable("savedJsonTable") checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + ctx.sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), (1 to 4).map(i => Row(i, s"str$i"))) checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + ctx.sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) - invalidateTable("savedJsonTable") + ctx.invalidateTable("savedJsonTable") checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + ctx.sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), (1 to 4).map(i => Row(i, s"str$i"))) checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + ctx.sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) } } @@ -462,7 +460,7 @@ class MetastoreDataSourcesSuite // Save the df as a managed table (by not specifying the path). df.write.saveAsTable("savedJsonTable") - checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) // Right now, we cannot append to an existing JSON table. 
intercept[RuntimeException] { @@ -471,17 +469,17 @@ class MetastoreDataSourcesSuite // We can overwrite it. df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") - checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) // When the save mode is Ignore, we will do nothing when the table already exists. df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") - assert(df.schema === table("savedJsonTable").schema) - checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + assert(df.schema === ctx.table("savedJsonTable").schema) + checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") + ctx.sql("DROP TABLE savedJsonTable") intercept[InvalidInputException] { - read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + ctx.read.json(ctx.catalog.hiveDefaultTableFilePath("savedJsonTable")) } } @@ -493,12 +491,12 @@ class MetastoreDataSourcesSuite .option("path", tempPath.toString) .saveAsTable("savedJsonTable") - checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) } // Data should not be deleted after we drop the table. - sql("DROP TABLE savedJsonTable") - checkAnswer(read.json(tempPath.toString), df) + ctx.sql("DROP TABLE savedJsonTable") + checkAnswer(ctx.read.json(tempPath.toString), df) } } } @@ -506,7 +504,7 @@ class MetastoreDataSourcesSuite test("create external table") { withTempPath { tempPath => withTable("savedJsonTable", "createdJsonTable") { - val df = read.json(sparkContext.parallelize((1 to 10).map { i => + val df = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map { i => s"""{ "a": $i, "b": "str$i" }""" })) @@ -519,39 +517,39 @@ class MetastoreDataSourcesSuite } withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { - createExternalTable("createdJsonTable", tempPath.toString) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + ctx.createExternalTable("createdJsonTable", tempPath.toString) + assert(ctx.table("createdJsonTable").schema === df.schema) + checkAnswer(ctx.sql("SELECT * FROM createdJsonTable"), df) assert( intercept[AnalysisException] { - createExternalTable("createdJsonTable", jsonFilePath.toString) + ctx.createExternalTable("createdJsonTable", jsonFilePath.toString) }.getMessage.contains("Table createdJsonTable already exists."), "We should complain that createdJsonTable already exists") } // Data should not be deleted. - sql("DROP TABLE createdJsonTable") - checkAnswer(read.json(tempPath.toString), df) + ctx.sql("DROP TABLE createdJsonTable") + checkAnswer(ctx.read.json(tempPath.toString), df) // Try to specify the schema. 
withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { val schema = StructType(StructField("b", StringType, true) :: Nil) - createExternalTable( + ctx.createExternalTable( "createdJsonTable", "org.apache.spark.sql.json", schema, Map("path" -> tempPath.toString)) checkAnswer( - sql("SELECT * FROM createdJsonTable"), - sql("SELECT b FROM savedJsonTable")) + ctx.sql("SELECT * FROM createdJsonTable"), + ctx.sql("SELECT b FROM savedJsonTable")) - sql("DROP TABLE createdJsonTable") + ctx.sql("DROP TABLE createdJsonTable") assert( intercept[RuntimeException] { - createExternalTable( + ctx.createExternalTable( "createdJsonTable", "org.apache.spark.sql.json", schema, @@ -569,16 +567,16 @@ class MetastoreDataSourcesSuite (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") withTable("test_parquet_ctas") { - sql( + ctx.sql( """CREATE TABLE test_parquet_ctas STORED AS PARQUET |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 """.stripMargin) checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + ctx.sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), Row(3) :: Row(4) :: Nil) - table("test_parquet_ctas").queryExecution.optimizedPlan match { + ctx.table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") @@ -634,10 +632,10 @@ class MetastoreDataSourcesSuite .mode(SaveMode.Append) .saveAsTable("arrayInParquet") - refreshTable("arrayInParquet") + ctx.refreshTable("arrayInParquet") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + ctx.sql("SELECT a FROM arrayInParquet"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -693,10 +691,10 @@ class MetastoreDataSourcesSuite .mode(SaveMode.Append) .saveAsTable("mapInParquet") - refreshTable("mapInParquet") + ctx.refreshTable("mapInParquet") checkAnswer( - sql("SELECT a FROM mapInParquet"), + ctx.sql("SELECT a FROM mapInParquet"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -711,7 +709,7 @@ class MetastoreDataSourcesSuite val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) // Manually create a metastore data source table. 
- catalog.createDataSourceTable( + ctx.catalog.createDataSourceTable( tableName = "wide_schema", userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -719,9 +717,9 @@ class MetastoreDataSourcesSuite options = Map("path" -> "just a dummy path"), isExternal = false) - invalidateTable("wide_schema") + ctx.invalidateTable("wide_schema") - val actualSchema = table("wide_schema").schema + val actualSchema = ctx.table("wide_schema").schema assert(schema === actualSchema) } } @@ -742,12 +740,12 @@ class MetastoreDataSourcesSuite "EXTERNAL" -> "FALSE"), tableType = ManagedTable, serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) + "path" -> ctx.catalog.hiveDefaultTableFilePath(tableName))) - catalog.client.createTable(hiveTable) + ctx.catalog.client.createTable(hiveTable) - invalidateTable(tableName) - val actualSchema = table(tableName).schema + ctx.invalidateTable(tableName) + val actualSchema = ctx.table(tableName).schema assert(schema === actualSchema) } } @@ -758,8 +756,8 @@ class MetastoreDataSourcesSuite withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - invalidateTable(tableName) - val metastoreTable = catalog.client.getTable("default", tableName) + ctx.invalidateTable(tableName) + val metastoreTable = ctx.catalog.client.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val actualPartitionColumns = StructType( @@ -773,7 +771,7 @@ class MetastoreDataSourcesSuite // Check the content of the saved table. checkAnswer( - table(tableName).select("c", "b", "d", "a"), + ctx.table(tableName).select("c", "b", "d", "a"), df.select("c", "b", "d", "a")) } } @@ -786,7 +784,7 @@ class MetastoreDataSourcesSuite withTable("insertParquet") { createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + ctx.sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { @@ -795,12 +793,12 @@ class MetastoreDataSourcesSuite createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + ctx.sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { @@ -809,26 +807,26 @@ class MetastoreDataSourcesSuite createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + 
ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + ctx.sql("SELECT p.c1, c2 FROM insertParquet p"), (50 to 59).map(i => Row(i, s"str$i"))) createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + ctx.sql("SELECT p.c1, c2 FROM insertParquet p"), (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 6321c12777d3..997eda8fde36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -21,10 +21,7 @@ import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.{QueryTest, SaveMode} class MultiDatabaseSuite extends QueryTest with HiveTestUtils { - private val ctx = hiveContext - import ctx.sql - - private val df = ctx.range(10).coalesce(1) + private lazy val df = ctx.range(10).coalesce(1) test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => @@ -99,7 +96,7 @@ class MultiDatabaseSuite extends QueryTest with HiveTestUtils { test("Looks up tables in non-default database") { withTempDatabase { db => activateDatabase(db) { - sql("CREATE TABLE t (key INT)") + ctx.sql("CREATE TABLE t (key INT)") checkAnswer(ctx.table("t"), ctx.emptyDataFrame) } @@ -110,7 +107,7 @@ class MultiDatabaseSuite extends QueryTest with HiveTestUtils { test("Drops a table in a non-default database") { withTempDatabase { db => activateDatabase(db) { - sql(s"CREATE TABLE t (key INT)") + ctx.sql(s"CREATE TABLE t (key INT)") assert(ctx.tableNames().contains("t")) assert(!ctx.tableNames("default").contains("t")) } @@ -119,7 +116,7 @@ class MultiDatabaseSuite extends QueryTest with HiveTestUtils { assert(ctx.tableNames(db).contains("t")) activateDatabase(db) { - sql(s"DROP TABLE t") + ctx.sql(s"DROP TABLE t") assert(!ctx.tableNames().contains("t")) assert(!ctx.tableNames("default").contains("t")) } @@ -137,7 +134,7 @@ class MultiDatabaseSuite extends QueryTest with HiveTestUtils { val path = dir.getCanonicalPath activateDatabase(db) { - sql( + ctx.sql( s"""CREATE EXTERNAL TABLE t (id BIGINT) |PARTITIONED BY (p INT) |STORED AS PARQUET @@ -147,8 +144,8 @@ class MultiDatabaseSuite extends QueryTest with HiveTestUtils { checkAnswer(ctx.table("t"), ctx.emptyDataFrame) df.write.parquet(s"$path/p=1") - sql("ALTER TABLE t ADD PARTITION (p=1)") - sql("REFRESH TABLE t") + ctx.sql("ALTER TABLE t ADD PARTITION (p=1)") + ctx.sql("REFRESH TABLE t") checkAnswer(ctx.table("t"), df.withColumn("p", lit(1))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index d14db04eb990..1a5527cdfac6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.{Row, SQLConf} class ParquetHiveCompatibilitySuite extends HiveParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - private val ctx = hiveContext - /** * Set the 
staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 0f3a5e16088e..5b9bbddd0dc0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -19,15 +20,14 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx.sql +class QueryPartitionSuite extends QueryTest with HiveTestUtils { + import testImplicits._ test("SPARK-5068: query data when path doesn't exist"){ val testData = ctx.sparkContext.parallelize( @@ -36,19 +36,19 @@ class QueryPartitionSuite extends QueryTest with SharedHiveContext { val tmpDir = Files.createTempDir() // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + + ctx.sql(s"CREATE TABLE table_with_partition(key int,value string) " + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + "SELECT key,value FROM testData") // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), + checkAnswer(ctx.sql("select key,value from table_with_partition"), testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) @@ -58,10 +58,10 @@ class QueryPartitionSuite extends QueryTest with SharedHiveContext { .foreach { f => Utils.deleteRecursively(f) } // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), + checkAnswer(ctx.sql("select key,value from table_with_partition"), testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + ctx.sql("DROP TABLE table_with_partition") + ctx.sql("DROP TABLE createAndInsertTest") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 376dc7ebdd4a..b7a8ba493687 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class SerializationSuite extends SparkFunSuite with SharedHiveContext { - private val ctx = hiveContext +class SerializationSuite extends SparkFunSuite with HiveTestUtils { test("[SPARK-5840] HiveContext should be serializable") { ctx.hiveconf diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ba78cd932631..f6963c917d27 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,14 +22,15 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class StatisticsSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.sql +class StatisticsSuite extends QueryTest with HiveTestUtils { - ctx.reset() - ctx.cacheTables = false + protected override def beforeAll(): Unit = { + super.beforeAll() + ctx.reset() + ctx.cacheTables = false + } test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -78,32 +79,32 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table - sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + ctx.sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() + ctx.sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + ctx.sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") + ctx.sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) - sql("DROP TABLE analyzeTable").collect() + ctx.sql("DROP TABLE analyzeTable").collect() // Partitioned table - sql( + ctx.sql( """ |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) """.stripMargin).collect() - sql( + ctx.sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') |SELECT * FROM src """.stripMargin).collect() - sql( + ctx.sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') |SELECT * FROM src """.stripMargin).collect() - sql( + ctx.sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') |SELECT * FROM src @@ -111,14 +112,14 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) - sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") + ctx.sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) - sql("DROP TABLE analyzeTable_part").collect() + ctx.sql("DROP TABLE analyzeTable_part").collect() // Try to 
analyze a temp table - sql("""SELECT * FROM src""").registerTempTable("tempTable") + ctx.sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { ctx.analyze("tempTable") } @@ -126,7 +127,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { } test("estimates the size of a test MetastoreRelation") { - val df = sql("""SELECT * FROM src""") + val df = ctx.sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } @@ -144,7 +145,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ct: ClassTag[_]): Unit = { before() - var df = sql(query) + var df = ctx.sql(query) // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { @@ -165,8 +166,8 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") - df = sql(query) + ctx.sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") + df = ctx.sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -174,7 +175,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") + ctx.sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -198,7 +199,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin val answer = Row(86, "val_86") - var df = sql(leftSemiJoinQuery) + var df = ctx.sql(leftSemiJoinQuery) // Assert src has a size smaller than the threshold. 
val sizes = df.queryExecution.analyzed.collect { @@ -223,8 +224,8 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - df = sql(leftSemiJoinQuery) + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") + df = ctx.sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } @@ -236,7 +237,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 7355c62d2c8b..954e4201d923 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext +class UDFSuite extends QueryTest with HiveTestUtils { test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index b13c0cb82a23..2989680fd390 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { - protected final val ctx = hiveContext - import ctx.implicits._ + import testImplicits._ var originalUseAggregate2: Boolean = _ @@ -508,14 +507,14 @@ class SortBasedAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + super.beforeAll() originalUnsafeEnabled = ctx.conf.unsafeEnabled ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "false") - super.beforeAll() } override def afterAll(): Unit = { - super.afterAll() ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + super.afterAll() } } @@ -524,13 +523,13 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + super.beforeAll() originalUnsafeEnabled = ctx.conf.unsafeEnabled ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() } override def afterAll(): Unit = { - super.afterAll() ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + super.afterAll() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 6f26e2119c66..5182481297a8 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -19,22 +19,26 @@ package org.apache.spark.sql.hive.execution import java.io.File -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHiveContext.TestTable /** * A set of test cases based on the big-data-benchmark. * https://amplab.cs.berkeley.edu/benchmark/ */ -class BigDataBenchmarkSuite extends HiveComparisonTest with SharedHiveContext { - import ctx._ +class BigDataBenchmarkSuite extends HiveComparisonTest { private val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") private val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath - private val testTables = Seq( - TestTable( - "rankings", - s""" + + protected override def beforeAll(): Unit = { + super.beforeAll() + val _ctx = ctx + import _ctx._ + val testTables = Seq( + TestTable( + "rankings", + s""" |CREATE EXTERNAL TABLE rankings ( | pageURL STRING, | pageRank INT, @@ -42,9 +46,9 @@ class BigDataBenchmarkSuite extends HiveComparisonTest with SharedHiveContext { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "rankings").getCanonicalPath}" """.stripMargin.cmd), - TestTable( - "scratch", - s""" + TestTable( + "scratch", + s""" |CREATE EXTERNAL TABLE scratch ( | pageURL STRING, | pageRank INT, @@ -52,9 +56,9 @@ class BigDataBenchmarkSuite extends HiveComparisonTest with SharedHiveContext { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "scratch").getCanonicalPath}" """.stripMargin.cmd), - TestTable( - "uservisits", - s""" + TestTable( + "uservisits", + s""" |CREATE EXTERNAL TABLE uservisits ( | sourceIP STRING, | destURL STRING, @@ -68,15 +72,15 @@ class BigDataBenchmarkSuite extends HiveComparisonTest with SharedHiveContext { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), - TestTable( - "documents", - s""" + TestTable( + "documents", + s""" |CREATE EXTERNAL TABLE documents (line STRING) |STORED AS TEXTFILE |LOCATION "${new File(testDataDirectory, "crawl").getCanonicalPath}" """.stripMargin.cmd)) - - testTables.foreach(registerTestTable) + testTables.foreach(registerTestTable) + } if (!testDataDirectory.exists()) { // TODO: Auto download the files on demand. 
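[Editor's note, not part of the patch] The hunks above and below all perform the same mechanical change: suites stop mixing in SharedHiveContext and importing ctx.sql / ctx.implicits._, and instead mix in a HiveTestUtils-style trait, call ctx.sql(...) explicitly, and pull implicits from a shared testImplicits object. The following is a minimal sketch of what such a shared-context mixin might look like, using a plain local-mode SQLContext for brevity where the patch uses a Hive-aware test context; the trait name SharedContextSketch and everything except the testImplicits idiom are illustrative assumptions, not the patch's actual HiveTestUtils implementation.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{SQLContext, SQLImplicits}
    import org.scalatest.{BeforeAndAfterAll, Suite}

    trait SharedContextSketch extends BeforeAndAfterAll { self: Suite =>

      private var _ctx: SQLContext = _

      // Suites reference the context explicitly (ctx.sql, ctx.table, ctx.read, ...)
      // instead of importing those methods into scope, which is the mechanical
      // change most hunks in this patch make.
      protected def ctx: SQLContext = _ctx

      // toDF() and friends come from a stable object rather than from
      // `import ctx.implicits._`, so a suite can `import testImplicits._`
      // even though the context is only created in beforeAll().
      protected object testImplicits extends SQLImplicits {
        protected override def _sqlContext: SQLContext = self.ctx
      }

      protected override def beforeAll(): Unit = {
        super.beforeAll()
        val conf = new SparkConf().setMaster("local[2]").setAppName("shared-context-sketch")
        _ctx = new SQLContext(new SparkContext(conf))
      }

      protected override def afterAll(): Unit = {
        try {
          if (_ctx != null) {
            _ctx.sparkContext.stop()
            _ctx = null
          }
        } finally {
          super.afterAll()
        }
      }
    }

A suite would then mix in this trait, `import testImplicits._`, and write `ctx.sql("...")`, which is the shape every converted suite in this patch ends up with.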
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index a001bba46f00..e1c883cd9f10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,12 +22,13 @@ import java.io._ import org.scalatest.GivenWhenThen import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils /** * Allows the creations of tests that execute the same query against both hive @@ -42,11 +43,9 @@ import org.apache.spark.sql.hive.test.SharedHiveContext abstract class HiveComparisonTest extends SparkFunSuite with GivenWhenThen - with SharedHiveContext + with HiveTestUtils with Logging { - protected val ctx = hiveContext - /** * When set, any cache files that result in test failures will be deleted. Used when the test * harness or hive have been updated thus requiring new golden answers to be computed for some @@ -133,9 +132,9 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } - protected def prepareAnswer( - hiveQuery: ctx.type#QueryExecution, - answer: Seq[String]): Seq[String] = { + private def prepareAnswer(ctx: SQLContext)( + hiveQuery: ctx.type#QueryExecution, + answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false @@ -272,6 +271,8 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } + val _ctx = ctx + try { if (reset) { ctx.reset() @@ -303,7 +304,7 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new ctx.QueryExecution(_)) + val hiveQueries = queryList.map(new _ctx.QueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. // Note this must only look at the logical plan as we might not be able to analyze if // other DDL has not been executed yet. @@ -353,8 +354,8 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new ctx.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + val query = new _ctx.QueryExecution(queryString) + try { (query, prepareAnswer(_ctx)(query, query.stringResult())) } catch { case e: Throwable => val errorMessage = s""" @@ -373,7 +374,7 @@ abstract class HiveComparisonTest (queryList, hiveResults, catalystResults).zipped.foreach { case (query, hive, (hiveQuery, catalyst)) => // Check that the results match unless its an EXPLAIN query. - val preparedHive = prepareAnswer(hiveQuery, hive) + val preparedHive = prepareAnswer(_ctx)(hiveQuery, hive) // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && @@ -413,7 +414,7 @@ abstract class HiveComparisonTest // okay by running a simple query. 
If this fails then we halt testing since // something must have gone seriously wrong. try { - new ctx.QueryExecution("SELECT key FROM src").stringResult() + new _ctx.QueryExecution("SELECT key FROM src").stringResult() ctx.runSqlHive("SELECT key FROM src") } catch { case e: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 3cd8ed3125d5..e0ca8a9fcd7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,23 +18,21 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx._ +class HiveExplainSuite extends QueryTest with HiveTestUtils { test("explain extended command") { - checkExistence(sql(" explain select * from src where key=123 "), true, + checkExistence(ctx.sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") - checkExistence(sql(" explain select * from src where key=123 "), false, + checkExistence(ctx.sql(" explain select * from src where key=123 "), false, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==") - checkExistence(sql(" explain extended select * from src where key=123 "), true, + checkExistence(ctx.sql(" explain extended select * from src where key=123 "), true, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", @@ -43,13 +41,13 @@ class HiveExplainSuite extends QueryTest with SharedHiveContext { } test("explain create table command") { - checkExistence(sql("explain create table temp__b as select * from src limit 2"), true, + checkExistence(ctx.sql("explain create table temp__b as select * from src limit 2"), true, "== Physical Plan ==", "InsertIntoHiveTable", "Limit", "src") - checkExistence(sql("explain extended create table temp__b as select * from src limit 2"), true, + checkExistence(ctx.sql("explain extended create table temp__b as select * from src limit 2"), true, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", @@ -59,7 +57,7 @@ class HiveExplainSuite extends QueryTest with SharedHiveContext { "Limit", "src") - checkExistence(sql( + checkExistence(ctx.sql( """ | EXPLAIN EXTENDED CREATE TABLE temp__b | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index bd7a456fcf50..a35548a6a979 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,35 +18,33 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils /** * A set of tests that validates commands can also be queried by like a table */ -class 
HiveOperatorQueryableSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx._ +class HiveOperatorQueryableSuite extends QueryTest with HiveTestUtils { test("SPARK-5324 query result of describe command") { - loadTestTable("src") + ctx.loadTestTable("src") // register a describe command to be a temp table - sql("desc src").registerTempTable("mydesc") + ctx.sql("desc src").registerTempTable("mydesc") checkAnswer( - sql("desc mydesc"), + ctx.sql("desc mydesc"), Seq( Row("col_name", "string", "name of the column"), Row("data_type", "string", "data type of the column"), Row("comment", "string", "comment of the column"))) checkAnswer( - sql("select * from mydesc"), + ctx.sql("select * from mydesc"), Seq( Row("key", "int", null), Row("value", "string", null))) checkAnswer( - sql("select col_name, data_type, comment from mydesc"), + ctx.sql("select col_name, data_type, comment from mydesc"), Seq( Row("key", "int", null), Row("value", "string", null))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index 1d0ce8e304ad..3b9ed9107363 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,17 +21,15 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils -class HivePlanTest extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class HivePlanTest extends QueryTest with HiveTestUtils { + import testImplicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") - val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan + val optimized = ctx.sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = ctx.sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 3380129b29a4..af254ccaf050 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,40 +30,40 @@ import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.test.SQLTestData.TestData -case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. 
*/ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedHiveContext { - import ctx.implicits._ - import ctx._ +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { + import testImplicits._ private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - override def beforeAll() { - cacheTables = true + override def beforeAll(): Unit = { + super.beforeAll() + ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) } - override def afterAll() { - cacheTables = false + override def afterAll(): Unit = { + ctx.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - sql("DROP TEMPORARY FUNCTION udtf_count2") + ctx.sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => - sql("USE default") - sql("SHOW DATABASES") + ctx.sql("USE default") + ctx.sql("SHOW DATABASES") } } @@ -147,11 +147,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH test("multiple generators in projection") { intercept[AnalysisException] { - sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() + ctx.sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() } intercept[AnalysisException] { - sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() + ctx.sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() } } @@ -241,8 +241,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH """.stripMargin) test("CREATE TABLE AS runs once") { - sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + ctx.sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(ctx.sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, "Incorrect number of rows in created table") } @@ -254,7 +254,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH // Jdk version leads to different query output for double, so not use createQueryTest here test("division") { - val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head + val res = ctx.sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x => assert(x._1 == x._2.asInstanceOf[Double])) } @@ -264,17 +264,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { - setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Row(1))) - setConf("spark.sql.dialect", "hiveql") + ctx.setConf("spark.sql.dialect", "sql") + assert(ctx.sql("SELECT 1").collect() === Array(Row(1))) + ctx.setConf("spark.sql.dialect", "hiveql") } test("Query expressed in HiveQL") { - sql("FROM src SELECT key").collect() + ctx.sql("FROM src SELECT key").collect() } test("Query with constant folding the CAST") { - sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect() + ctx.sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect() } createQueryTest("Constant Folding 
Optimization for AVG_SUM_COUNT", @@ -373,10 +373,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH """.stripMargin) test("SPARK-7270: consider dynamic partition when comparing table output") { - sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") - sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + ctx.sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + ctx.sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") - val analyzedPlan = sql( + val analyzedPlan = ctx.sql( """ |INSERT OVERWRITE table test_partition PARTITION (b=1, c) |SELECT 'a', 'c' from ptest @@ -430,11 +430,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH test("transform with SerDe2") { - sql("CREATE TABLE small_src(key INT, value STRING)") - sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") + ctx.sql("CREATE TABLE small_src(key INT, value STRING)") + ctx.sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") - val expected = sql("SELECT key FROM small_src").collect().head - val res = sql( + val expected = ctx.sql("SELECT key FROM small_src").collect().head + val res = ctx.sql( """ |SELECT TRANSFORM (key) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.avro.AvroSerDe' @@ -508,13 +508,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") test("sampling") { - sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") - sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") + ctx.sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + ctx.sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") } test("DataFrame toString") { - sql("SHOW TABLES").toString - sql("SELECT * FROM src").toString + ctx.sql("SHOW TABLES").toString + ctx.sql("SELECT * FROM src").toString } createQueryTest("case statements with key #1", @@ -543,7 +543,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { - val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + val res = ctx.sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head assert(0.001 == res.getDouble(0)) } @@ -587,7 +587,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH """.stripMargin) test("predicates contains an empty AttributeSet() references") { - sql( + ctx.sql( """ |SELECT a FROM ( | SELECT 1 AS a FROM src LIMIT 1 ) t @@ -596,12 +596,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH } test("implement identity function using case statement") { - val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + val actual = ctx.sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => i } .collect() .toSet - val expected = sql("SELECT key FROM src") + val expected = ctx.sql("SELECT key FROM src") .map { case Row(i: Int) => i } .collect() .toSet @@ -613,7 +613,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. 
ignore("non-boolean conditions in a CaseWhen are illegal") { intercept[Exception] { - sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + ctx.sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() } } @@ -622,13 +622,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH test("case sensitivity: registered table") { val testData = - sparkContext.parallelize( + ctx.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) testData.toDF().registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { - sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + ctx.sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } } @@ -639,92 +639,92 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH } test("SPARK-1704: Explain commands as a DataFrame") { - sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + ctx.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val df = sql("explain select key, count(value) from src group by key") + val df = ctx.sql("explain select key, count(value) from src group by key") assert(isExplanation(df)) - reset() + ctx.reset() } test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") + ctx.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = - sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + ctx.sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) assert(results === Array(Pair("foo", 4))) - reset() + ctx.reset() } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { - sql("select key, count(*) c from src group by key having c").collect() + ctx.sql("select key, count(*) c from src group by key having c").collect() } test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(sql("select key from src having key > 490").collect().size < 100) + assert(ctx.sql("select key from src having key > 490").collect().size < 100) } test("SPARK-5383 alias for udfs with multi output columns") { assert( - sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + ctx.sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") .collect() .size == 5) assert( - sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + ctx.sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") .collect() .size == 5) } test("SPARK-5367: resolve star expression in udf") { - assert(sql("select concat(*) from src limit 5").collect().size == 5) - assert(sql("select array(*) from src limit 5").collect().size == 5) - assert(sql("select concat(key, *) from src limit 5").collect().size == 5) - assert(sql("select array(key, *) from src limit 5").collect().size == 5) + assert(ctx.sql("select concat(*) from src limit 5").collect().size == 5) + assert(ctx.sql("select array(*) from src limit 5").collect().size == 5) + assert(ctx.sql("select concat(key, *) from src limit 5").collect().size == 5) + assert(ctx.sql("select 
array(key, *) from src limit 5").collect().size == 5) } test("Query Hive native command execution result") { val databaseName = "test_native_commands" assertResult(0) { - sql(s"DROP DATABASE IF EXISTS $databaseName").count() + ctx.sql(s"DROP DATABASE IF EXISTS $databaseName").count() } assertResult(0) { - sql(s"CREATE DATABASE $databaseName").count() + ctx.sql(s"CREATE DATABASE $databaseName").count() } assert( - sql("SHOW DATABASES") + ctx.sql("SHOW DATABASES") .select('result) .collect() .map(_.getString(0)) .contains(databaseName)) - assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) + assert(isExplanation(ctx.sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) - reset() + ctx.reset() } test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = ctx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail - assert(Try(table(tableName)).isSuccess) + assert(Try(ctx.table(tableName)).isSuccess) // If the CREATE TABLE command got executed again, the following assertion would fail assert(Try(q0.count()).isSuccess) } test("DESCRIBE commands") { - sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + ctx.sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - sql( + ctx.sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') |SELECT key, value """.stripMargin) @@ -739,7 +739,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH Row("# col_name", "data_type", "comment"), Row("dt", "string", null)) ) { - sql("DESCRIBE test_describe_commands1") + ctx.sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } @@ -754,14 +754,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH Row("# col_name", "data_type", "comment"), Row("dt", "string", null)) ) { - sql("DESCRIBE default.test_describe_commands1") + ctx.sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - sql("DESCRIBE test_describe_commands1 value") + ctx.sql("DESCRIBE test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -769,7 +769,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - sql("DESCRIBE default.test_describe_commands1 value") + ctx.sql("DESCRIBE default.test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -787,7 +787,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH Array(""), Array("dt", "string")) ) { - sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + ctx.sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") .select('result) .collect() .map(_.getString(0).replaceAll("None", "").trim.split("\t").map(_.trim)) @@ -795,7 +795,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH // Describe a registered temporary table. 
val testData = - sparkContext.parallelize( + ctx.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) testData.toDF().registerTempTable("test_describe_commands2") @@ -805,16 +805,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH Row("a", "int", ""), Row("b", "string", "")) ) { - sql("DESCRIBE test_describe_commands2") + ctx.sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) .collect() } } test("SPARK-2263: Insert Map values") { - sql("CREATE TABLE m(value MAP)") - sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + ctx.sql("CREATE TABLE m(value MAP)") + ctx.sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + ctx.sql("SELECT * FROM m").collect().zip(ctx.sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -822,35 +822,35 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH } test("ADD JAR command") { - val testJar = getHiveFile("data/files/TestSerDe.jar").getCanonicalPath - sql("CREATE TABLE alter1(a INT, b INT)") + val testJar = ctx.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + ctx.sql("CREATE TABLE alter1(a INT, b INT)") intercept[Exception] { - sql( + ctx.sql( """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' |WITH serdeproperties('s1'='9') """.stripMargin) } - sql("DROP TABLE alter1") + ctx.sql("DROP TABLE alter1") } test("ADD JAR command 2") { // this is a test case from mapjoin_addjar.q - val testJar = getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath - val testData = getHiveFile("data/files/sample.json").getCanonicalPath - sql(s"ADD JAR $testJar") - sql( + val testJar = ctx.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = ctx.getHiveFile("data/files/sample.json").getCanonicalPath + ctx.sql(s"ADD JAR $testJar") + ctx.sql( """CREATE TABLE t1(a string, b string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") + ctx.sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + ctx.sql("select * from src join t1 on src.key = t1.a") + ctx.sql("DROP TABLE t1") } test("ADD FILE command") { - val testFile = getHiveFile("data/files/v1.txt").getCanonicalFile - sql(s"ADD FILE $testFile") + val testFile = ctx.getHiveFile("data/files/v1.txt").getCanonicalFile + ctx.sql(s"ADD FILE $testFile") - val checkAddFileRDD = sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => + val checkAddFileRDD = ctx.sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => Iterator.single(new File(SparkFiles.get("v1.txt")).canRead) } @@ -883,9 +883,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH """.stripMargin) ignore("Dynamic partition folder layout") { - sql("DROP TABLE IF EXISTS dynamic_part_table") - sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") - sql("SET hive.exec.dynamic.partition.mode=nonstrict") + ctx.sql("DROP TABLE IF EXISTS dynamic_part_table") + ctx.sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") val 
data = Map( Seq("1", "1") -> 1, @@ -894,7 +894,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH Seq("NULL", "NULL") -> 4) data.foreach { case (parts, value) => - sql( + ctx.sql( s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 """.stripMargin) @@ -911,18 +911,18 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH .mkString("/") // Loads partition data to a temporary table to verify contents - val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + val path = s"${ctx.warehousePath}/dynamic_part_table/$partFolder/part-00000" - sql("DROP TABLE IF EXISTS dp_verify") - sql("CREATE TABLE dp_verify(intcol INT)") - sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + ctx.sql("DROP TABLE IF EXISTS dp_verify") + ctx.sql("CREATE TABLE dp_verify(intcol INT)") + ctx.sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + assert(ctx.sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) } } test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") { - sql(""" + ctx.sql(""" |create table sc as select * |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows) |union all @@ -930,31 +930,31 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH |union all |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s """.stripMargin) - sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") - sql("set hive.exec.dynamic.partition=true") - sql("set hive.exec.dynamic.partition.mode=nonstrict") - sql("insert overwrite table sc_part partition(ts) select * from sc") - sql("drop table sc_part") + ctx.sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") + ctx.sql("set hive.exec.dynamic.partition=true") + ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") + ctx.sql("insert overwrite table sc_part partition(ts) select * from sc") + ctx.sql("drop table sc_part") } test("Partition spec validation") { - sql("DROP TABLE IF EXISTS dp_test") - sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") - sql("SET hive.exec.dynamic.partition.mode=strict") + ctx.sql("DROP TABLE IF EXISTS dp_test") + ctx.sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + ctx.sql("SET hive.exec.dynamic.partition.mode=strict") // Should throw when using strict dynamic partition mode without any static partition intercept[SparkException] { - sql( + ctx.sql( """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src """.stripMargin) } - sql("SET hive.exec.dynamic.partition.mode=nonstrict") + ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") // Should throw when a static partition appears after a dynamic partition intercept[SparkException] { - sql( + ctx.sql( """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) |SELECT key, value, key % 5 FROM src """.stripMargin) @@ -962,10 +962,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") + 
ctx.sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") + ctx.sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") - sql( + ctx.sql( """ SELECT name, message FROM rawLogs @@ -977,15 +977,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH """).registerTempTable("boom") // This should be successfully analyzed - sql("SELECT * FROM boom").queryExecution.analyzed + ctx.sql("SELECT * FROM boom").queryExecution.analyzed } test("SPARK-3810: PreInsertionCasts static partitioning support") { val analyzedPlan = { - loadTestTable("srcpart") - sql("DROP TABLE IF EXISTS withparts") - sql("CREATE TABLE withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + ctx.loadTestTable("srcpart") + ctx.sql("DROP TABLE IF EXISTS withparts") + ctx.sql("CREATE TABLE withparts LIKE srcpart") + ctx.sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") .queryExecution.analyzed } @@ -998,13 +998,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { val analyzedPlan = { - loadTestTable("srcpart") - sql("DROP TABLE IF EXISTS withparts") - sql("CREATE TABLE withparts LIKE srcpart") - sql("SET hive.exec.dynamic.partition.mode=nonstrict") + ctx.loadTestTable("srcpart") + ctx.sql("DROP TABLE IF EXISTS withparts") + ctx.sql("CREATE TABLE withparts LIKE srcpart") + ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") - sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + ctx.sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + ctx.sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") .queryExecution.analyzed } @@ -1020,19 +1020,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH val testKey = "spark.sql.key.usedfortestonly" val testVal = "val0,val_1,val2.3,my_table" - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") == testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") == "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") == "40") - sql(s"set $testKey=$testVal") - assert(getConf(testKey, "0") == testVal) + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, "0") == testVal) - sql(s"set $testKey=") - assert(getConf(testKey, "0") == "") + ctx.sql(s"set $testKey=") + assert(ctx.getConf(testKey, "0") == "") } test("SET commands semantics for a HiveContext") { @@ -1045,38 +1045,38 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter with SharedH case Row(key: String, value: String) => key -> value case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet - conf.clear() + ctx.conf.clear() - val expectedConfs = conf.getAllDefinedConfs.toSet - assertResult(expectedConfs)(collectResults(sql("SET -v"))) + val expectedConfs = ctx.conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(ctx.sql("SET -v"))) // "SET" itself returns all config variables currently specified in SQLConf. 
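The hunks above follow the pattern applied throughout this patch: members that used to be pulled in with `import ctx._` (`sql`, `getConf`, `conf`, `sparkContext`) are now called explicitly on the `ctx` reference supplied by the shared test trait, and only `testImplicits._` is imported for `toDF()` and the symbol/column syntax. A minimal sketch of the suite shape this converges on, using only the `QueryTest`/`HiveTestUtils`/`testImplicits`/`ctx` names that appear elsewhere in this patch (illustrative only, not part of the diff):

class ExampleConfSuite extends QueryTest with HiveTestUtils {
  import testImplicits._   // only the implicits are imported; everything else goes through ctx

  test("conf keys are set and read through the explicit context") {
    val testKey = "spark.sql.key.usedfortestonly"
    ctx.sql(s"SET $testKey=someValue")                // explicit receiver instead of a bare sql(...)
    assert(ctx.getConf(testKey, "") === "someValue")  // same getConf(key, default) call as in the hunk above
    ctx.conf.clear()                                  // clean up, mirroring the SET-semantics test
  }
}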
// TODO: Should we be listing the default here always? probably... - assert(sql("SET").collect().size == 0) + assert(ctx.sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + collectResults(ctx.sql(s"SET $testKey=$testVal")) } - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assert(ctx.hiveconf.get(testKey, "") == testVal) + assertResult(Set(testKey -> testVal))(collectResults(ctx.sql("SET"))) - sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + ctx.sql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(ctx.hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + collectResults(ctx.sql("SET")) } // "SET key" assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + collectResults(ctx.sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + collectResults(ctx.sql(s"SET $nonexistentKey")) } - conf.clear() + ctx.conf.clear() } createQueryTest("select from thrift based table", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 915256c5d3cc..fa69c3b84c02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -27,23 +27,22 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) * included in the hive distribution. 
*/ class HiveResolutionSuite extends HiveComparisonTest { - import ctx.implicits._ - import ctx._ + import testImplicits._ test("SPARK-3698: case insensitive test for nested data") { - read.json(sparkContext.makeRDD( + ctx.read.json(ctx.sparkContext.makeRDD( """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed - sql("SELECT a[0].A.A from nested").queryExecution.analyzed + ctx.sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - read.json(sparkContext.makeRDD( + ctx.read.json(ctx.sparkContext.makeRDD( """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error val exception = intercept[AnalysisException] { - sql("SELECT a[0].b from nested").queryExecution.analyzed + ctx.sql("SELECT a[0].b from nested").queryExecution.analyzed } assert(exception.getMessage.contains("Ambiguous reference to fields")) } @@ -77,10 +76,10 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") - val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + val query = ctx.sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), "The output schema did not preserve the case of the query.") query.collect() @@ -88,16 +87,16 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") - sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() + ctx.sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") - assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + assert(ctx.sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } createQueryTest("test ambiguousReferences resolved as hive", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index c508c3b7d7df..ece3715e19f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,23 +17,21 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.SharedHiveContext - /** * A set of tests that validates support for Hive SerDe. 
*/ -class HiveSerDeSuite extends HiveComparisonTest with SharedHiveContext { +class HiveSerDeSuite extends HiveComparisonTest { import org.apache.hadoop.hive.serde2.RegexSerDe - import ctx._ override def beforeAll(): Unit = { super.beforeAll() ctx.cacheTables = false - sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + ctx.sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") """.stripMargin) - sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") + ctx.sql( + s"LOAD DATA LOCAL INPATH '${ctx.getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } // table sales is not a cache table, and will be clear after reset diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index dc2f56cb3369..5e8f9f961b89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { - import ctx.implicits._ - import ctx._ + import testImplicits._ createQueryTest("partition_based_table_scan_with_different_serde", """ @@ -54,15 +52,15 @@ class HiveTableScanSuite extends HiveComparisonTest { """.stripMargin) test("Spark-4041: lowercase issue") { - sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") - sql("insert into table tb select key, value from src") - sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() - sql("drop table tb") + ctx.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") + ctx.sql("insert into table tb select key, value from src") + ctx.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() + ctx.sql("drop table tb") } test("Spark-4077: timestamp query for null value") { - sql("DROP TABLE IF EXISTS timestamp_query_null") - sql( + ctx.sql("DROP TABLE IF EXISTS timestamp_query_null") + ctx.sql( """ CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) ROW FORMAT DELIMITED @@ -72,20 +70,20 @@ class HiveTableScanSuite extends HiveComparisonTest { val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(sql("SELECT time from timestamp_query_null limit 2").collect() + ctx.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") + assert(ctx.sql("SELECT time from timestamp_query_null limit 2").collect() === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) - sql("DROP TABLE timestamp_query_null") + ctx.sql("DROP TABLE timestamp_query_null") } test("Spark-4959 Attributes are case sensitive when using a select query from a projection") { - sql("create table spark_4959 (col1 string)") - sql("""insert into table spark_4959 select "hi" from src limit 1""") - table("spark_4959").select( + ctx.sql("create table spark_4959 (col1 string)") + ctx.sql("""insert into table spark_4959 select "hi" from src limit 1""") + ctx.table("spark_4959").select( 'col1.as("CaseSensitiveColName"), 
'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) + assert(ctx.sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(ctx.sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 2604ee5d2794..3e07d28354ba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.util.Utils case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) @@ -46,14 +46,12 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class HiveUDFSuite extends QueryTest with HiveTestUtils { + import testImplicits._ test("spark sql udf test that returns a struct") { - udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) - assert(sql( + ctx.udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(ctx.sql( """ |SELECT getStruct(1).f1, | getStruct(1).f2, @@ -65,13 +63,13 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( - sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), + ctx.sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), Row(8) ) } test("hive struct udf") { - sql( + ctx.sql( """ |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT @@ -83,25 +81,25 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { stripMargin.format(classOf[PairSerDe].getName)) val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile - sql(s""" + ctx.sql(s""" ALTER TABLE hiveUDFTestTable ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") - sql("SELECT testUDF(pair) FROM hiveUDFTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + ctx.sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("Max/Min on named_struct") { def testOrderInStruct(): Unit = { - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT max(named_struct( | "key", key, | "value", value)).value FROM src """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT min(named_struct( | "key", key, @@ -109,7 +107,7 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { """.stripMargin), Seq(Row("val_0"))) // nested struct cases - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT max(named_struct( | "key", named_struct( @@ -117,7 +115,7 @@ class HiveUDFSuite extends QueryTest with 
SharedHiveContext { "value", value), | "value", value)).value FROM src """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT min(named_struct( | "key", named_struct( @@ -126,176 +124,176 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } - val codegenDefault = getConf(SQLConf.CODEGEN_ENABLED) - setConf(SQLConf.CODEGEN_ENABLED, true) + val codegenDefault = ctx.getConf(SQLConf.CODEGEN_ENABLED) + ctx.setConf(SQLConf.CODEGEN_ENABLED, true) testOrderInStruct() - setConf(SQLConf.CODEGEN_ENABLED, false) + ctx.setConf(SQLConf.CODEGEN_ENABLED, false) testOrderInStruct() - setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + ctx.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { - sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") checkAnswer( - sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + ctx.sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) - sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - reset() + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + ctx.reset() } test("SPARK-2693 udaf aggregates test") { - checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src").collect().toSeq) + checkAnswer(ctx.sql("SELECT percentile(key, 1) FROM src LIMIT 1"), + ctx.sql("SELECT max(key) FROM src").collect().toSeq) - checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) + checkAnswer(ctx.sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + ctx.sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } test("Generic UDAF aggregates") { - checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) + checkAnswer(ctx.sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + ctx.sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), - sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) + checkAnswer(ctx.sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + ctx.sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } test("UDFIntegerToString") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") val udfName = classOf[UDFIntegerToString].getName - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), + ctx.sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - reset() + ctx.reset() } test("UDFToListString") { - val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - sql(s"CREATE TEMPORARY 
FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") val errMsg = intercept[AnalysisException] { - sql("SELECT testUDFToListString(s) FROM inputTable") + ctx.sql("SELECT testUDFToListString(s) FROM inputTable") } assert(errMsg.getMessage === "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - reset() + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") + ctx.reset() } test("UDFToListInt") { - val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") val errMsg = intercept[AnalysisException] { - sql("SELECT testUDFToListInt(s) FROM inputTable") + ctx.sql("SELECT testUDFToListInt(s) FROM inputTable") } assert(errMsg.getMessage === "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - reset() + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") + ctx.reset() } test("UDFToStringIntMap") { - val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + s"AS '${classOf[UDFToStringIntMap].getName}'") val errMsg = intercept[AnalysisException] { - sql("SELECT testUDFToStringIntMap(s) FROM inputTable") + ctx.sql("SELECT testUDFToStringIntMap(s) FROM inputTable") } assert(errMsg.getMessage === "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - reset() + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") + ctx.reset() } test("UDFToIntIntMap") { - val testData = sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + s"AS '${classOf[UDFToIntIntMap].getName}'") val errMsg = intercept[AnalysisException] { - sql("SELECT testUDFToIntIntMap(s) FROM inputTable") + ctx.sql("SELECT testUDFToIntIntMap(s) FROM inputTable") } assert(errMsg.getMessage === "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - reset() + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") + ctx.reset() } test("UDFListListInt") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("listListIntTable") - sql(s"CREATE TEMPORARY FUNCTION 
testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), + ctx.sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - reset() + ctx.reset() } test("UDFListString") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), + ctx.sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - reset() + ctx.reset() } test("UDFStringString") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), + ctx.sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - reset() + ctx.reset() } test("UDFTwoListList") { - val testData = sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("TwoListTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), + ctx.sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - reset() + ctx.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 563251ee4ba0..ceb190fe7e91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.BeforeAndAfter - /* Implicit conversions */ import scala.collection.JavaConversions._ /** * A set of test cases that validate partition and column pruning. 
*/ -class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - ctx.cacheTables = false - - // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset - // the environment to ensure all referenced tables in this suites are not cached in-memory. - // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. - ctx.reset() +class PruningSuite extends HiveComparisonTest { + + protected override def beforeAll(): Unit = { + super.beforeAll() + ctx.cacheTables = false + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset + // the environment to ensure all referenced tables in this suites are not cached in-memory. + // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. + ctx.reset() + } // Column pruning tests @@ -143,7 +144,8 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val plan = new ctx.QueryExecution(sql).executedPlan + val _ctx = ctx + val plan = new _ctx.QueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 69d4335c762f..85444a2d42e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.HiveTestUtils import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) @@ -62,33 +63,31 @@ class MyDialect extends DefaultParserDialect * valid, but Hive currently cannot execute it. 
*/ class SQLQuerySuite extends QueryTest with HiveTestUtils { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + import testImplicits._ test("UDTF") { - sql(s"ADD JAR ${getHiveFile("TestUDTF.jar").getCanonicalPath()}") + ctx.sql(s"ADD JAR ${ctx.getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( + ctx.sql( """ |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + ctx.sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), Row(97, 500) :: Row(97, 500) :: Nil) checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + ctx.sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), Row(3) :: Row(3) :: Nil) } test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + val query = ctx.sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) } @@ -113,7 +112,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { orders.toDF.registerTempTable("orders1") orderUpdates.toDF.registerTempTable("orderupdates1") - sql( + ctx.sql( """CREATE TABLE orders( | id INT, | make String, @@ -126,7 +125,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { |STORED AS PARQUET """.stripMargin) - sql( + ctx.sql( """CREATE TABLE orderupdates( | id INT, | make String, @@ -139,12 +138,12 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { |STORED AS PARQUET """.stripMargin) - sql("set hive.exec.dynamic.partition.mode=nonstrict") - sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") - sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") + ctx.sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + ctx.sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") checkAnswer( - sql( + ctx.sql( """ |select orders.state, orders.month |from orders @@ -162,22 +161,22 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val allFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted - checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) - checkAnswer(sql("SHOW functions abs"), Row("abs")) - checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) - checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) - checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `~`"), Row("~")) - checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + checkAnswer(ctx.sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(ctx.sql("SHOW functions abs"), Row("abs")) + checkAnswer(ctx.sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(ctx.sql("SHOW functions abc.abs"), Row("abs")) + checkAnswer(ctx.sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(ctx.sql("SHOW functions 
`abc`.`abs`"), Row("abs")) + checkAnswer(ctx.sql("SHOW functions `~`"), Row("~")) + checkAnswer(ctx.sql("SHOW functions `a function doens't exist`"), Nil) + checkAnswer(ctx.sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer(ctx.sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) } test("describe functions") { // The Spark SQL built-in functions - checkExistence(sql("describe function extended upper"), true, + checkExistence(ctx.sql("describe function extended upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase", @@ -185,18 +184,18 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { "> SELECT upper('SparkSql')", "'SPARKSQL'") - checkExistence(sql("describe functioN Upper"), true, + checkExistence(ctx.sql("describe functioN Upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase") - checkExistence(sql("describe functioN Upper"), false, + checkExistence(ctx.sql("describe functioN Upper"), false, "Extended Usage") - checkExistence(sql("describe functioN abcadf"), true, + checkExistence(ctx.sql("describe functioN abcadf"), true, "Function: abcadf is not found.") - checkExistence(sql("describe functioN `~`"), true, + checkExistence(ctx.sql("describe functioN `~`"), true, "Function: ~", "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", "Usage: ~ n - Bitwise not") @@ -206,7 +205,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - val query = sql( + val query = ctx.sql( """ |SELECT | MIN(c1), @@ -230,7 +229,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - sql( + ctx.sql( """ |CREATE TABLE with_table1 AS |WITH T AS ( @@ -240,27 +239,27 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { |SELECT * |FROM T """.stripMargin) - val query = sql("SELECT * FROM with_table1") + val query = ctx.sql("SELECT * FROM with_table1") checkAnswer(query, Row(1, 1) :: Nil) } test("explode nested Field") { Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( - sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), + ctx.sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) } test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( - sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), - sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq + ctx.sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), + ctx.sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq ) } test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(ctx.catalog.lookupRelation(Seq(tableName))) relation match { case LogicalRelation(r: ParquetRelation) => if (!isDataSourceParquet) { @@ -278,89 +277,89 @@ class SQLQuerySuite extends QueryTest 
with HiveTestUtils { } } - val originalConf = convertCTAS + val originalConf = ctx.convertCTAS - setConf(HiveContext.CONVERT_CTAS, true) + ctx.setConf(HiveContext.CONVERT_CTAS, true) try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") var message = intercept[AnalysisException] { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage assert(message.contains("ctas1 already exists")) checkRelation("ctas1", true) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") // Specifying database name for query can be converted to data source write path // is not allowed right now. message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage assert( message.contains("Cannot specify database name in a CTAS statement"), "When spark.sql.hive.convertCTAS is true, we should not allow " + "database name specified.") - sql("CREATE TABLE ctas1 stored as textfile" + + ctx.sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", true) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") - sql("CREATE TABLE ctas1 stored as sequencefile" + + ctx.sql("CREATE TABLE ctas1 stored as sequencefile" + " AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", true) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - sql("DROP TABLE ctas1") + ctx.sql("DROP TABLE ctas1") } finally { - setConf(HiveContext.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + ctx.setConf(HiveContext.CONVERT_CTAS, originalConf) + ctx.sql("DROP TABLE IF EXISTS ctas1") } } test("SQL Dialect Switching") { - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) - setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(getSQLDialect().getClass === classOf[MyDialect]) - assert(sql("SELECT 1").collect() === Array(Row(1))) + assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) + ctx.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(ctx.getSQLDialect().getClass === classOf[MyDialect]) + assert(ctx.sql("SELECT 1").collect() === Array(Row(1))) // set the dialect back to the DefaultSQLDialect - sql("SET spark.sql.dialect=sql") - 
assert(getSQLDialect().getClass === classOf[DefaultParserDialect]) - sql("SET spark.sql.dialect=hiveql") - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + ctx.sql("SET spark.sql.dialect=sql") + assert(ctx.getSQLDialect().getClass === classOf[DefaultParserDialect]) + ctx.sql("SET spark.sql.dialect=hiveql") + assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) // set invalid dialect - sql("SET spark.sql.dialect.abc=MyTestClass") - sql("SET spark.sql.dialect=abc") + ctx.sql("SET spark.sql.dialect.abc=MyTestClass") + ctx.sql("SET spark.sql.dialect=abc") intercept[Exception] { - sql("SELECT 1") + ctx.sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - getSQLDialect().getClass === classOf[HiveQLDialect] + ctx.getSQLDialect().getClass === classOf[HiveQLDialect] - sql("SET spark.sql.dialect=MyTestClass") + ctx.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { - sql("SELECT 1") + ctx.sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() - sql( + ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() + ctx.sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") @@ -370,7 +369,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { | SELECT key, value | FROM src | ORDER BY key, value""".stripMargin).collect() - sql( + ctx.sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS @@ -379,41 +378,41 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { | ORDER BY key, value""".stripMargin).collect() // the table schema may like (key: integer, value: string) - sql( + ctx.sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() // do nothing cause the table ctas4 already existed. 
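Further down in this file, the ctas5 checks wrap their assertions in `withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { ... }` instead of the manual save/`setConf`/`finally`-restore sequence used for `CONVERT_CTAS` above. A short sketch of that scoped form, assuming `withSQLConf` comes from the shared `SQLTestUtils` helpers and puts the previous values back when the block exits (illustrative only, not part of the diff):

// Inside the block the conf is pinned to "false"; the prior value is restored
// afterwards, so later tests in the suite are unaffected.
withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") {
  checkAnswer(
    ctx.sql("SELECT key, value FROM ctas5 ORDER BY key, value"),
    ctx.sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq)
}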
- sql( + ctx.sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() checkAnswer( - sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + ctx.sql("SELECT k, value FROM ctas1 ORDER BY k, value"), + ctx.sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) checkAnswer( - sql("SELECT key, value FROM ctas2 ORDER BY key, value"), - sql( + ctx.sql("SELECT key, value FROM ctas2 ORDER BY key, value"), + ctx.sql( """ SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) checkAnswer( - sql("SELECT key, value FROM ctas3 ORDER BY key, value"), - sql( + ctx.sql("SELECT key, value FROM ctas3 ORDER BY key, value"), + ctx.sql( """ SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) intercept[AnalysisException] { - sql( + ctx.sql( """CREATE TABLE ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() } checkAnswer( - sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) + ctx.sql("SELECT key, value FROM ctas4 ORDER BY key, value"), + ctx.sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - checkExistence(sql("DESC EXTENDED ctas2"), true, + checkExistence(ctx.sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", @@ -421,7 +420,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - sql( + ctx.sql( """CREATE TABLE ctas5 | STORED AS parquet AS | SELECT key, value @@ -429,7 +428,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { | ORDER BY key, value""".stripMargin).collect() withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - checkExistence(sql("DESC EXTENDED ctas5"), true, + checkExistence(ctx.sql("DESC EXTENDED ctas5"), true, "name:key", "type:string", "name:value", "ctas5", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", @@ -441,57 +440,57 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { // use the Hive SerDe for parquet tables withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + ctx.sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + ctx.sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) } } test("specifying the column list for CTAS") { Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") - sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + ctx.sql("create table gen__tmp(a int, b string) as select key, value from mytable1") checkAnswer( - sql("SELECT a, b from gen__tmp"), - sql("select key, value from mytable1").collect()) - sql("DROP TABLE gen__tmp") + ctx.sql("SELECT a, b from gen__tmp"), + ctx.sql("select key, value from mytable1").collect()) + ctx.sql("DROP TABLE gen__tmp") - sql("create table gen__tmp(a double, b double) as select key, value from mytable1") + ctx.sql("create table gen__tmp(a double, b double) as select key, value from mytable1") checkAnswer( - sql("SELECT a, b from gen__tmp"), - sql("select cast(key as double), 
cast(value as double) from mytable1").collect()) - sql("DROP TABLE gen__tmp") + ctx.sql("SELECT a, b from gen__tmp"), + ctx.sql("select cast(key as double), cast(value as double) from mytable1").collect()) + ctx.sql("DROP TABLE gen__tmp") - sql("drop table mytable1") + ctx.sql("drop table mytable1") } test("command substitution") { - sql("set tbl=src") + ctx.sql("set tbl=src") checkAnswer( - sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), - sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + ctx.sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + ctx.sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) - sql("set hive.variable.substitute=false") // disable the substitution - sql("set tbl2=src") + ctx.sql("set hive.variable.substitute=false") // disable the substitution + ctx.sql("set tbl2=src") intercept[Exception] { - sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + ctx.sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() } - sql("set hive.variable.substitute=true") // enable the substitution + ctx.sql("set hive.variable.substitute=true") // enable the substitution checkAnswer( - sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), - sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + ctx.sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + ctx.sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) } test("ordering not in select") { checkAnswer( - sql("SELECT key FROM src ORDER BY value"), - sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) + ctx.sql("SELECT key FROM src ORDER BY value"), + ctx.sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) } test("ordering not in agg") { checkAnswer( - sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), - sql(""" + ctx.sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), + ctx.sql(""" SELECT key FROM ( SELECT key, value @@ -501,103 +500,103 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + ctx.sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) .toDF().registerTempTable("nested") checkAnswer( - sql("SELECT f1.f2.f3 FROM nested"), + ctx.sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), + checkAnswer(ctx.sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), Seq.empty[Row]) checkAnswer( - sql("SELECT * FROM test_ctas_1234"), - sql("SELECT * FROM nested").collect().toSeq) + ctx.sql("SELECT * FROM test_ctas_1234"), + ctx.sql("SELECT * FROM nested").collect().toSeq) intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + ctx.sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + checkAnswer(ctx.sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) checkAnswer( - sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), - sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + ctx.sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + ctx.sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } test("SPARK-4825 save join to table") { - 
val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() - sql("CREATE TABLE test1 (key INT, value STRING)") + val testData = ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() + ctx.sql("CREATE TABLE test1 (key INT, value STRING)") testData.write.mode(SaveMode.Append).insertInto("test1") - sql("CREATE TABLE test2 (key INT, value STRING)") + ctx.sql("CREATE TABLE test2 (key INT, value STRING)") testData.write.mode(SaveMode.Append).insertInto("test2") testData.write.mode(SaveMode.Append).insertInto("test2") - sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") + ctx.sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( - table("test"), - sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) + ctx.table("test"), + ctx.sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) } test("SPARK-3708 Backticks aren't handled correctly is aliases") { checkAnswer( - sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), - sql("SELECT `key` FROM src").collect().toSeq) + ctx.sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), + ctx.sql("SELECT `key` FROM src").collect().toSeq) } test("SPARK-3834 Backticks not correctly handled in subquery aliases") { checkAnswer( - sql("SELECT a.key FROM (SELECT key FROM src) `a`"), - sql("SELECT `key` FROM src").collect().toSeq) + ctx.sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + ctx.sql("SELECT `key` FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise & operator") { checkAnswer( - sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), - sql("SELECT 1 FROM src").collect().toSeq) + ctx.sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), + ctx.sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise | operator") { checkAnswer( - sql("SELECT case when 1|0=1 then 1 else 0 end FROM src"), - sql("SELECT 1 FROM src").collect().toSeq) + ctx.sql("SELECT case when 1|0=1 then 1 else 0 end FROM src"), + ctx.sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise ^ operator") { checkAnswer( - sql("SELECT case when 1^0=1 then 1 else 0 end FROM src"), - sql("SELECT 1 FROM src").collect().toSeq) + ctx.sql("SELECT case when 1^0=1 then 1 else 0 end FROM src"), + ctx.sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise ~ operator") { checkAnswer( - sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), - sql("SELECT 1 FROM src").collect().toSeq) + ctx.sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), + ctx.sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), - sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + checkAnswer(ctx.sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + ctx.sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) } test("SPARK-2554 SumDistinct partial aggregation") { - checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), - sql("SELECT distinct key FROM src order by key").collect().toSeq) + checkAnswer(ctx.sql("SELECT sum( distinct key) FROM src group by key order by key"), + ctx.sql("SELECT distinct key FROM src order by key").collect().toSeq) } test("SPARK-4963 
DataFrame sample on mutable row return wrong result") { - sql("SELECT * FROM src WHERE key % 2 = 0") + ctx.sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) .registerTempTable("sampled") (1 to 10).foreach { i => checkAnswer( - sql("SELECT * FROM sampled WHERE key % 2 = 1"), + ctx.sql("SELECT * FROM sampled WHERE key % 2 = 1"), Seq.empty[Row]) } } test("SPARK-4699 HiveContext should be case insensitive by default") { checkAnswer( - sql("SELECT KEY FROM Src ORDER BY value"), - sql("SELECT key FROM src ORDER BY value").collect().toSeq) + ctx.sql("SELECT KEY FROM Src ORDER BY value"), + ctx.sql("SELECT key FROM src ORDER BY value").collect().toSeq) } test("SPARK-5284 Insert into Hive throws NPE when a inner complex type field has a null value") { @@ -609,74 +608,74 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { StructField("innerMap", MapType(StringType, IntegerType)) :: Nil), true) :: Nil) val row = Row(Row(null, null, null)) - val rowRdd = sparkContext.parallelize(row :: Nil) + val rowRdd = ctx.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRdd, schema).registerTempTable("testTable") + ctx.createDataFrame(rowRdd, schema).registerTempTable("testTable") - sql( + ctx.sql( """CREATE TABLE nullValuesInInnerComplexTypes | (s struct, | innerArray:array, | innerMap: map>) """.stripMargin).collect() - sql( + ctx.sql( """ |INSERT OVERWRITE TABLE nullValuesInInnerComplexTypes |SELECT * FROM testTable """.stripMargin) checkAnswer( - sql("SELECT * FROM nullValuesInInnerComplexTypes"), + ctx.sql("SELECT * FROM nullValuesInInnerComplexTypes"), Row(Row(null, null, null)) ) - sql("DROP TABLE nullValuesInInnerComplexTypes") - dropTempTable("testTable") + ctx.sql("DROP TABLE nullValuesInInnerComplexTypes") + ctx.dropTempTable("testTable") } test("SPARK-4296 Grouping field with Hive UDF as sub expression") { - val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - read.json(rdd).registerTempTable("data") + val rdd = ctx.sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + ctx.read.json(rdd).registerTempTable("data") checkAnswer( - sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + ctx.sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) - dropTempTable("data") + ctx.dropTempTable("data") - read.json(rdd).registerTempTable("data") - checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + ctx.read.json(rdd).registerTempTable("data") + checkAnswer(ctx.sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) - dropTempTable("data") + ctx.dropTempTable("data") } test("resolve udtf in projection #1") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).registerTempTable("data") - val df = sql("SELECT explode(a) AS val FROM data") + val rdd = ctx.sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + ctx.read.json(rdd).registerTempTable("data") + val df = ctx.sql("SELECT explode(a) AS val FROM data") val col = df("val") } test("resolve udtf in projection #2") { - val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") - checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) - checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) + val rdd = 
ctx.sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) + ctx.jsonRDD(rdd).registerTempTable("data") + checkAnswer(ctx.sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) + checkAnswer(ctx.sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { - sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") + ctx.sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") } intercept[AnalysisException] { - sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") + ctx.sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") } } // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { - val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + val rdd = ctx.sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) + ctx.jsonRDD(rdd).registerTempTable("data") checkAnswer( - sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), + ctx.sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) } @@ -687,40 +686,40 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).registerTempTable("data") - val originalConf = convertCTAS - setConf(HiveContext.CONVERT_CTAS, false) + val rdd = ctx.sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + ctx.read.json(rdd).registerTempTable("data") + val originalConf = ctx.convertCTAS + ctx.setConf(HiveContext.CONVERT_CTAS, false) try { - sql("CREATE TABLE explodeTest (key bigInt)") - table("explodeTest").queryExecution.analyzed match { + ctx.sql("CREATE TABLE explodeTest (key bigInt)") + ctx.table("explodeTest").queryExecution.analyzed match { case metastoreRelation: MetastoreRelation => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } - sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + ctx.sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") checkAnswer( - sql("SELECT key from explodeTest"), + ctx.sql("SELECT key from explodeTest"), (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) ) - sql("DROP TABLE explodeTest") - dropTempTable("data") + ctx.sql("DROP TABLE explodeTest") + ctx.dropTempTable("data") } finally { - setConf(HiveContext.CONVERT_CTAS, originalConf) + ctx.setConf(HiveContext.CONVERT_CTAS, originalConf) } } test("sanity test for SPARK-6618") { (1 to 100).par.map { i => val tableName = s"SPARK_6618_table_$i" - sql(s"CREATE TABLE $tableName (col1 string)") - catalog.lookupRelation(Seq(tableName)) - table(tableName) - tables() - sql(s"DROP TABLE $tableName") + ctx.sql(s"CREATE TABLE $tableName (col1 string)") + ctx.catalog.lookupRelation(Seq(tableName)) + ctx.table(tableName) + ctx.tables() + ctx.sql(s"DROP TABLE $tableName") } } @@ -730,7 +729,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") - sql("select d from dn union all select d * 2 from dn") + ctx.sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } @@ -738,7 +737,7 @@ class SQLQuerySuite 
extends QueryTest with HiveTestUtils { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + ctx.sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") .queryExecution.toRdd.count()) } @@ -746,7 +745,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") + ctx.sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") .queryExecution.toRdd.count()) } @@ -759,10 +758,10 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - sql( + ctx.sql( """ |select area, sum(product), sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -777,7 +776,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - sql( + ctx.sql( """ |select area, sum(product) - 1, sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -792,7 +791,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - sql( + ctx.sql( """ |select area, sum(product), sum(product) / sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -807,7 +806,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - sql( + ctx.sql( """ |select area, sum(product), sum(product) / sum(sum(product) - 1) over (partition by area) |from windowData group by month, area @@ -831,10 +830,10 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - sql( + ctx.sql( """ |select month, area, product, sum(product + 1) over (partition by 1 order by 2) |from windowData @@ -849,7 +848,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { ).map(i => Row(i._1, i._2, i._3, i._4))) checkAnswer( - sql( + ctx.sql( """ |select month, area, product, sum(product) |over (partition by month % 2 order by 10 - product) @@ -874,10 +873,10 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - sql( + ctx.sql( """ |select month, area, month % 2, |lag(product, 1 + 1, product) over (partition by month % 2 order by area) @@ -894,7 +893,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { } test("window function: multiple window expressions in a single expression") { - val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + val nums = ctx.sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") val expected = @@ -909,7 +908,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { Row(1, 9, 45, 
55, 25, 125) :: Row(0, 10, 55, 55, 30, 140) :: Nil - val actual = sql( + val actual = ctx.sql( """ |SELECT | y, @@ -926,18 +925,18 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { checkAnswer(actual, expected) - dropTempTable("nums") + ctx.dropTempTable("nums") } test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( - sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), + ctx.sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) } test("SPARK-7595: Window will cause resolve failed with self join") { - checkAnswer(sql( + checkAnswer(ctx.sql( """ |with | v1 as (select key, count(value) over (partition by key) cnt_val from src), @@ -950,27 +949,27 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { Seq(1, 2, 3).map { i => (i.toString, i.toString) }.toDF("key", "value").registerTempTable("df_analysis") - sql("SELECT kEy from df_analysis group by key").collect() - sql("SELECT kEy+3 from df_analysis group by key+3").collect() - sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() - sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() - sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() - sql("SELECT 2 from df_analysis A group by key+1").collect() + ctx.sql("SELECT kEy from df_analysis group by key").collect() + ctx.sql("SELECT kEy+3 from df_analysis group by key+3").collect() + ctx.sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() + ctx.sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() + ctx.sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() + ctx.sql("SELECT 2 from df_analysis A group by key+1").collect() intercept[AnalysisException] { - sql("SELECT kEy+1 from df_analysis group by key+3") + ctx.sql("SELECT kEy+1 from df_analysis group by key+3") } intercept[AnalysisException] { - sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") + ctx.sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") } } test("Cast STRING to BIGINT") { - checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) + checkAnswer(ctx.sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) } // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest test("udf_java_method") { - checkAnswer(sql( + checkAnswer(ctx.sql( """ |SELECT java_method("java.lang.String", "valueOf", 1), | java_method("java.lang.String", "isEmpty"), @@ -993,34 +992,34 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { test("dynamic partition value test") { try { - sql("set hive.exec.dynamic.partition.mode=nonstrict") + ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") // date - sql("drop table if exists dynparttest1") - sql("create table dynparttest1 (value int) partitioned by (pdate date)") - sql( + ctx.sql("drop table if exists dynparttest1") + ctx.sql("create table dynparttest1 (value int) partitioned by (pdate date)") + ctx.sql( """ |insert into table dynparttest1 partition(pdate) | select count(*), cast('2015-05-21' as date) as pdate from src """.stripMargin) checkAnswer( - sql("select * from dynparttest1"), + ctx.sql("select * from dynparttest1"), Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) // decimal - 
sql("drop table if exists dynparttest2") - sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") - sql( + ctx.sql("drop table if exists dynparttest2") + ctx.sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + ctx.sql( """ |insert into table dynparttest2 partition(pdec) | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src """.stripMargin) checkAnswer( - sql("select * from dynparttest2"), + ctx.sql("select * from dynparttest2"), Seq(Row(500, new java.math.BigDecimal("100.1")))) } finally { - sql("drop table if exists dynparttest1") - sql("drop table if exists dynparttest2") - sql("set hive.exec.dynamic.partition.mode=strict") + ctx.sql("drop table if exists dynparttest1") + ctx.sql("drop table if exists dynparttest2") + ctx.sql("set hive.exec.dynamic.partition.mode=strict") } } @@ -1029,7 +1028,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. - ctx.sql(s"ADD JAR ${getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + ctx.sql(s"ADD JAR ${ctx.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { ctx.sql( """ @@ -1053,14 +1052,14 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { checkAnswer( - sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + ctx.sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), Row(false)) } test("SPARK-6785: HiveQuerySuite - Date cast") { // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST checkAnswer( - sql( + ctx.sql( """ | SELECT | CAST(CAST(0 AS timestamp) AS date), @@ -1098,6 +1097,6 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") - checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + checkAnswer(ctx.sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 2a6d9d0d57bf..72457f1d2390 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.hive.test.HiveSparkPlanTest import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends HiveSparkPlanTest { - private val ctx = hiveContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 759d537041d0..7da1242a283f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = hiveContext - import 
ctx.implicits._ - import ctx._ + import testImplicits._ override val dataSourceName: String = classOf[DefaultSource].getCanonicalName @@ -38,7 +36,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext + ctx.sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write @@ -49,7 +47,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.options(Map( + ctx.read.options(Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index a273467695cb..0cda8cb03115 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -25,8 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.SharedHiveContext -import org.apache.spark.util.Utils +import org.apache.spark.sql.hive.test.HiveTestUtils // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -35,24 +34,16 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ +class OrcPartitionDiscoverySuite extends QueryTest with HiveTestUtils { + import testImplicits._ val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal - def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) } - def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { df.write.mode("overwrite").orc(path.getCanonicalPath) @@ -90,11 +81,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).registerTempTable("t") + ctx.read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -102,7 +93,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, i.toString, pi, ps)) checkAnswer( - sql("SELECT intField, pi FROM t"), + ctx.sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -110,14 +101,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi)) checkAnswer( - sql("SELECT * FROM t WHERE pi = 1"), + ctx.sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield 
Row(i, i.toString, 1, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps = 'foo'"), + ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -137,11 +128,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).registerTempTable("t") + ctx.read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -149,7 +140,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi, i.toString, ps)) checkAnswer( - sql("SELECT intField, pi FROM t"), + ctx.sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -157,14 +148,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi)) checkAnswer( - sql("SELECT * FROM t WHERE pi = 1"), + ctx.sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, 1, i.toString, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps = 'foo'"), + ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -186,14 +177,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read + ctx.read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -201,14 +192,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, i.toString, pi, ps)) checkAnswer( - sql("SELECT * FROM t WHERE pi IS NULL"), + ctx.sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) } yield Row(i, i.toString, null, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps IS NULL"), + ctx.sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -228,14 +219,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read + ctx.read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { checkAnswer( - sql("SELECT * FROM t"), + ctx.sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -243,7 +234,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi, i.toString, ps)) checkAnswer( - sql("SELECT * FROM t WHERE ps IS NULL"), + ctx.sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 415363793b83..31dd2ef96d52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -46,9 +46,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with OrcTest { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + 
import testImplicits._ def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -70,7 +68,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => - val bytes = read.orc(file).head().getAs[Array[Byte]](0) + val bytes = ctx.read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, "utf8") === "test") } } @@ -88,16 +86,16 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + ctx.read.orc(file), data.toDF().collect()) } } test("Creating case class RDD table") { val data = (1 to 100).map(i => (i, s"val_$i")) - sparkContext.parallelize(data).toDF().registerTempTable("t") + ctx.sparkContext.parallelize(data).toDF().registerTempTable("t") withTempTable("t") { - checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) + checkAnswer(ctx.sql("SELECT * FROM t"), data.toDF().collect()) } } @@ -110,13 +108,13 @@ class OrcQuerySuite extends QueryTest with OrcTest { // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = leaf-0 - assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) + assert(ctx.sql("SELECT name FROM t WHERE age <= 5").count() === 5) // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = (not leaf-0) assertResult(10) { - sql("SELECT name, contacts FROM t where age > 5") + ctx.sql("SELECT name, contacts FROM t where age > 5") .flatMap(_.getAs[Seq[_]]("contacts")) .count() } @@ -126,7 +124,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // leaf-1 = (LESS_THAN age 8) // expr = (and (not leaf-0) leaf-1) { - val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + val df = ctx.sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") assert(df.count() === 2) assertResult(4) { df.flatMap(_.getAs[Seq[_]]("contacts")).count() @@ -138,7 +136,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // leaf-1 = (LESS_THAN_EQUALS age 8) // expr = (or leaf-0 (not leaf-1)) { - val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + val df = ctx.sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") assert(df.count() === 3) assertResult(6) { df.flatMap(_.getAs[Seq[_]]("contacts")).count() @@ -158,7 +156,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + ctx.read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -175,7 +173,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // Following codec is supported in hive-0.13.1, ignore it now ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { val data = (1 to 100).map(i => (i, s"val_$i")) - val conf = sparkContext.hadoopConfiguration + val conf = ctx.sparkContext.hadoopConfiguration conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") withOrcFile(data) { file => @@ -202,33 +200,33 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("simple select queries") { withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { checkAnswer( - sql("SELECT `_1` FROM t where t.`_1` > 5"), + ctx.sql("SELECT `_1` FROM t where t.`_1` > 5"), (6 until 10).map(Row.apply(_))) checkAnswer( - sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + ctx.sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), (0 until 5).map(Row.apply(_))) } } test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", 
"c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withOrcTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + ctx.sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withOrcTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + ctx.sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -239,7 +237,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { } withOrcTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val selfJoin = ctx.sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) @@ -254,7 +252,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("nested data - struct with array field") { val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withOrcTable(data, "t") { - checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + checkAnswer(ctx.sql("SELECT `_1`.`_2`[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) }) } @@ -263,7 +261,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("nested data - array of struct") { val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withOrcTable(data, "t") { - checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + checkAnswer(ctx.sql("SELECT `_1`[0].`_2` FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) }) } @@ -271,18 +269,18 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("columns only referenced by pushed down filters should remain") { withOrcTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(ctx.sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) } } test("SPARK-5309 strings stored using dictionary compression in orc") { withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { checkAnswer( - sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + ctx.sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), (0 until 10).map(i => Row("same", "run_" + i, 100))) checkAnswer( - sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + ctx.sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), List(Row("same", "run_5", 100))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 7e25aab1b9ac..40e2bc71b3f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.{QueryTest, Row} -import 
org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.HiveTestUtils case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with SharedHiveContext { - protected val ctx = hiveContext - import ctx.implicits._ - import ctx._ +abstract class OrcSuite extends QueryTest with HiveTestUtils { + import testImplicits._ var orcTableDir: File = null var orcTableAsDir: File = null @@ -44,13 +42,13 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { orcTableDir.delete() orcTableDir.mkdir() - sparkContext + ctx.sparkContext .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() .registerTempTable(s"orc_temp_table") - sql( + ctx.sql( s"""CREATE EXTERNAL TABLE normal_orc( | intField INT, | stringField STRING @@ -59,7 +57,7 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { |LOCATION '${orcTableAsDir.getCanonicalPath}' """.stripMargin) - sql( + ctx.sql( s"""INSERT INTO TABLE normal_orc |SELECT intField, stringField FROM orc_temp_table """.stripMargin) @@ -71,66 +69,65 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { } test("create temporary orc table") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + checkAnswer(ctx.sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) checkAnswer( - sql("SELECT * FROM normal_orc_source"), + ctx.sql("SELECT * FROM normal_orc_source"), (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT * FROM normal_orc_source where intField > 5"), + ctx.sql("SELECT * FROM normal_orc_source where intField > 5"), (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + ctx.sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), (1 to 10).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + checkAnswer(ctx.sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) checkAnswer( - sql("SELECT * FROM normal_orc_source"), + ctx.sql("SELECT * FROM normal_orc_source"), (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + ctx.sql("SELECT * FROM normal_orc_source WHERE intField > 5"), (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + ctx.sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), (1 to 10).map(i => Row(1, s"part-$i"))) } test("appending insert") { - sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + ctx.sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") checkAnswer( - sql("SELECT * FROM normal_orc_source"), + ctx.sql("SELECT * FROM normal_orc_source"), (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => Seq.fill(2)(Row(i, s"part-$i")) }) } test("overwrite insert") { - sql( + ctx.sql( """INSERT OVERWRITE TABLE normal_orc_as_source |SELECT * FROM orc_temp_table WHERE intField > 5 """.stripMargin) checkAnswer( - sql("SELECT * FROM normal_orc_as_source"), + ctx.sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } } class OrcSourceSuite extends OrcSuite { - import ctx._ override def beforeAll(): Unit = { super.beforeAll() - sql( + ctx.sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING 
org.apache.spark.sql.hive.orc |OPTIONS ( @@ -138,7 +135,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - sql( + ctx.sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 2c946314f209..09518696b974 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -27,9 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.HiveTestUtils private[sql] trait OrcTest extends SparkFunSuite with HiveTestUtils { - private val ctx = hiveContext - import ctx.implicits._ - import ctx.sparkContext + import testImplicits._ /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` @@ -39,7 +37,7 @@ private[sql] trait OrcTest extends SparkFunSuite with HiveTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) + ctx.sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) f(file.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index fb40cd8e7be8..6efe90d3bd85 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -53,8 +53,6 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { - private val ctx = hiveContext - import ctx._ override def beforeAll(): Unit = { super.beforeAll() @@ -66,7 +64,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { "jt", "jt_array", "test_parquet") - sql(s""" + ctx.sql(s""" create external table partitioned_parquet ( intField INT, @@ -80,7 +78,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${partitionedTableDir.getCanonicalPath}' """) - sql(s""" + ctx.sql(s""" create external table partitioned_parquet_with_key ( intField INT, @@ -94,7 +92,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${partitionedTableDirWithKey.getCanonicalPath}' """) - sql(s""" + ctx.sql(s""" create external table normal_parquet ( intField INT, @@ -107,7 +105,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${new File(normalTableDir, "normal").getCanonicalPath}' """) - sql(s""" + ctx.sql(s""" CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes ( intField INT, @@ -123,7 +121,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}' """) - sql(s""" + ctx.sql(s""" CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes ( intField INT, @@ -139,7 +137,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' """) - sql( + ctx.sql( """ |create table test_parquet |( @@ -153,27 +151,27 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + ctx.sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - 
sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") + ctx.sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") + ctx.sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") + ctx.sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") } - val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - read.json(rdd1).registerTempTable("jt") - val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - read.json(rdd2).registerTempTable("jt_array") + val rdd1 = ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + ctx.read.json(rdd1).registerTempTable("jt") + val rdd2 = ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) + ctx.read.json(rdd2).registerTempTable("jt_array") - setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) + ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } override def afterAll(): Unit = { @@ -185,31 +183,31 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { "jt", "jt_array", "test_parquet") - setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) + ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + ctx.sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: HiveTableScan => true }.isEmpty) assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + ctx.sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: PhysicalRDD => true }.nonEmpty) } test("scan an empty parquet table") { - checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) + checkAnswer(ctx.sql("SELECT count(*) FROM test_parquet"), Row(0)) } test("scan an empty parquet table with upper case") { - checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) + checkAnswer(ctx.sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) } test("insert into an empty parquet table") { dropTables("test_insert_parquet") - sql( + ctx.sql( """ |create table test_insert_parquet |( @@ -223,21 +221,21 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Insert into am empty table. - sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") + ctx.sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") checkAnswer( - sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), + ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), Row(6, "str6") :: Row(7, "str7") :: Nil ) // Insert overwrite. - sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + ctx.sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") checkAnswer( - sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) dropTables("test_insert_parquet") // Create it again. 
- sql( + ctx.sql( """ |create table test_insert_parquet |( @@ -250,15 +248,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) // Insert overwrite an empty table. - sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + ctx.sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") checkAnswer( - sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) // Insert into the table. - sql("insert into table test_insert_parquet select a, b from jt") + ctx.sql("insert into table test_insert_parquet select a, b from jt") checkAnswer( - sql(s"SELECT intField, stringField FROM test_insert_parquet"), + ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) dropTables("test_insert_parquet") @@ -266,7 +264,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("scan a parquet table created through a CTAS statement") { withTable("test_parquet_ctas") { - sql( + ctx.sql( """ |create table test_parquet_ctas ROW FORMAT |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' @@ -277,11 +275,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) checkAnswer( - sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + ctx.sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), Seq(Row(1, "str1")) ) - table("test_parquet_ctas").queryExecution.optimizedPlan match { + ctx.table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(_: ParquetRelation) => // OK case _ => fail( "test_parquet_ctas should be converted to " + @@ -292,7 +290,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("MetastoreRelation in InsertIntoTable will be converted") { withTable("test_insert_parquet") { - sql( + ctx.sql( """ |create table test_insert_parquet |( @@ -304,7 +302,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + val df = ctx.sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + @@ -314,15 +312,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ctx.sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + ctx.sql("SELECT a FROM jt WHERE jt.a > 5").collect() ) } } test("MetastoreRelation in InsertIntoHiveTable will be converted") { withTable("test_insert_parquet") { - sql( + ctx.sql( """ |create table test_insert_parquet |( @@ -334,7 +332,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + val df = ctx.sql("INSERT INTO TABLE test_insert_parquet 
SELECT a FROM jt_array") df.queryExecution.executedPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + @@ -344,15 +342,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() + ctx.sql("SELECT int_array FROM test_insert_parquet"), + ctx.sql("SELECT a FROM jt_array").collect() ) } } test("SPARK-6450 regression test") { withTable("ms_convert") { - sql( + ctx.sql( """CREATE TABLE IF NOT EXISTS ms_convert (key INT) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS @@ -361,7 +359,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // This shouldn't throw AnalysisException - val analyzed = sql( + val analyzed = ctx.sql( """SELECT key FROM ms_convert |UNION ALL |SELECT key FROM ms_convert @@ -386,7 +384,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { withTable("nonPartitioned") { - sql( + ctx.sql( s"""CREATE TABLE nonPartitioned ( | key INT, | value STRING @@ -395,9 +393,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) + val r1 = collectParquetRelation(ctx.table("nonPartitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) + val r2 = collectParquetRelation(ctx.table("nonPartitioned")) // They should be the same instance assert(r1 eq r2) } @@ -405,7 +403,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { withTable("partitioned") { - sql( + ctx.sql( s"""CREATE TABLE partitioned ( | key INT, | value STRING @@ -415,18 +413,19 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) + val r1 = collectParquetRelation(ctx.table("partitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) + val r2 = collectParquetRelation(ctx.table("partitioned")) // They should be the same instance assert(r1 eq r2) } } test("Caching converted data source Parquet Relations") { - def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { + val _ctx = ctx + def checkCached(tableIdentifier: _ctx.catalog.QualifiedTableName): Unit = { // Converted test_parquet should be cached. 
- catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { + ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK case other => @@ -438,7 +437,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") - sql( + ctx.sql( """ |create table test_insert_parquet |( @@ -451,18 +450,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = _ctx.catalog.QualifiedTableName("default", "test_insert_parquet") // First, make sure the converted test_parquet is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Table lookup will make the table cached. - table("test_insert_parquet") + ctx.table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. - invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql( + ctx.invalidateTable("test_insert_parquet") + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + ctx.sql( """ |INSERT INTO TABLE test_insert_parquet |select a, b from jt @@ -470,14 +469,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - sql("select * from test_insert_parquet"), - sql("select a, b from jt").collect()) + ctx.sql("select * from test_insert_parquet"), + ctx.sql("select a, b from jt").collect()) // Invalidate the cache. - invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + ctx.invalidateTable("test_insert_parquet") + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Create a partitioned table. - sql( + ctx.sql( """ |create table test_parquet_partitioned_cache_test |( @@ -491,9 +490,10 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql( + tableIdentifier = _ctx.catalog.QualifiedTableName( + "default", "test_parquet_partitioned_cache_test") + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + ctx.sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-01') @@ -501,30 +501,30 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. 
- assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql( + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + ctx.sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Make sure we can cache the partitioned table. - table("test_parquet_partitioned_cache_test") + ctx.table("test_parquet_partitioned_cache_test") checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), - sql( + ctx.sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), + ctx.sql( """ |select b, '2015-04-01', a FROM jt |UNION ALL |select b, '2015-04-02', a FROM jt """.stripMargin).collect()) - invalidateTable("test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + ctx.invalidateTable("test_parquet_partitioned_cache_test") + assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } @@ -534,9 +534,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + import testImplicits._ override def beforeAll(): Unit = { super.beforeAll() @@ -546,7 +544,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { "partitioned_parquet_with_key_and_complextypes", "normal_parquet") - sql( s""" + ctx.sql( s""" create temporary table partitioned_parquet USING org.apache.spark.sql.parquet OPTIONS ( @@ -554,7 +552,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - sql( s""" + ctx.sql( s""" create temporary table partitioned_parquet_with_key USING org.apache.spark.sql.parquet OPTIONS ( @@ -562,7 +560,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - sql( s""" + ctx.sql( s""" create temporary table normal_parquet USING org.apache.spark.sql.parquet OPTIONS ( @@ -570,7 +568,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - sql( s""" + ctx.sql( s""" CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes USING org.apache.spark.sql.parquet OPTIONS ( @@ -578,7 +576,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - sql( s""" + ctx.sql( s""" CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes USING org.apache.spark.sql.parquet OPTIONS ( @@ -588,29 +586,29 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } test("SPARK-6016 make sure to use the latest footers") { - sql("drop table if exists spark_6016_fix") + ctx.sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. 
- val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + val df1 = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( - sql("select * from spark_6016_fix"), + ctx.sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + val df2 = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files // and then merge metadata in footers of these four (two outdated ones and two latest one), // which will cause an error. checkAnswer( - sql("select * from spark_6016_fix"), + ctx.sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) - sql("drop table spark_6016_fix") + ctx.sql("drop table spark_6016_fix") } test("SPARK-8811: compatibility with array of struct in Hive") { @@ -624,7 +622,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") withSQLConf(conf: _*) { - sql( + ctx.sql( s"""CREATE TABLE array_of_struct |STORED AS PARQUET LOCATION '$path' |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) @@ -639,7 +637,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } test("values in arrays and maps stored in parquet are always nullable") { - val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val df = ctx.createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) val arrayType1 = ArrayType(IntegerType, containsNull = false) val expectedSchema1 = @@ -658,10 +656,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest { StructField("m", mapType2, nullable = true) :: StructField("a", arrayType2, nullable = true) :: Nil) - assert(table("alwaysNullable").schema === expectedSchema2) + assert(ctx.table("alwaysNullable").schema === expectedSchema2) checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), + ctx.sql("SELECT m, a FROM alwaysNullable"), Row(Map(2 -> 3), Seq(4, 5, 6))) } } @@ -677,7 +675,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { val df3 = df2.toDF("str", "max_int") df3.write.parquet(filePath2) - val df4 = read.parquet(filePath2) + val df4 = ctx.read.parquet(filePath2) checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } @@ -687,9 +685,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { * A collection of tests for parquet data with various forms of partitioning. 
*/ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + import testImplicits._ var partitionedTableDir: File = null var normalTableDir: File = null @@ -703,13 +699,13 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { (1 to 10).foreach { p => val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) + ctx.sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) .toDF() .write.parquet(partDir.getCanonicalPath) } - sparkContext + ctx.sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) .toDF() @@ -719,7 +715,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) + ctx.sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) .toDF() .write.parquet(partDir.getCanonicalPath) @@ -729,7 +725,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") - sparkContext.makeRDD(1 to 10).map { i => + ctx.sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithKeyAndComplexTypes( p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) }.toDF().write.parquet(partDir.getCanonicalPath) @@ -739,7 +735,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") - sparkContext.makeRDD(1 to 10).map { i => + ctx.sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) }.toDF().write.parquet(partDir.getCanonicalPath) } @@ -759,7 +755,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { */ def dropTables(tableNames: String*): Unit = { tableNames.foreach { name => - sql(s"DROP TABLE IF EXISTS $name") + ctx.sql(s"DROP TABLE IF EXISTS $name") } } @@ -771,19 +767,19 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { test(s"ordering of the partitioning columns $table") { checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + ctx.sql(s"SELECT p, stringField FROM $table WHERE p = 1"), Seq.fill(10)(Row(1, "part-1")) ) checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + ctx.sql(s"SELECT stringField, p FROM $table WHERE p = 1"), Seq.fill(10)(Row("part-1", 1)) ) } test(s"project the partitioning column $table") { checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), + ctx.sql(s"SELECT p, count(*) FROM $table group by p"), Row(1, 10) :: Row(2, 10) :: Row(3, 10) :: @@ -799,7 +795,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { test(s"project partitioning and non-partitioning columns $table") { checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + ctx.sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), Row("part-1", 1, 10) :: Row("part-2", 2, 10) :: Row("part-3", 3, 10) :: @@ -815,44 +811,44 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { test(s"simple count $table") { checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), + ctx.sql(s"SELECT COUNT(*) FROM $table"), Row(100)) } test(s"pruned count $table") { checkAnswer( - 
sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), Row(10)) } test(s"non-existent partition $table") { checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), Row(0)) } test(s"multi-partition pruned count $table") { checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), Row(30)) } test(s"non-partition predicates $table") { checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + ctx.sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), Row(30)) } test(s"sum $table") { checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + ctx.sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), Row(1 + 2 + 3)) } test(s"hive udfs $table") { checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { + ctx.sql(s"SELECT concat(stringField, stringField) FROM $table"), + ctx.sql(s"SELECT stringField FROM $table").map { case Row(s: String) => Row(s + s) }.collect().toSeq) } @@ -864,7 +860,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { test(s"SPARK-5775 read struct from $table") { checkAnswer( - sql( + ctx.sql( s""" |SELECT p, structField.intStructField, structField.stringStructField |FROM $table WHERE p = 1 @@ -875,7 +871,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { // Re-enable this after SPARK-5508 is fixed ignore(s"SPARK-5775 read array from $table") { checkAnswer( - sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + ctx.sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), (1 to 10).map(i => Row(1 to i, 1))) } } @@ -883,7 +879,7 @@ abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { test("non-part select(*)") { checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), + ctx.sql("SELECT COUNT(*) FROM normal_parquet"), Row(10)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index d7ebafc3b01e..8bd7dba84ec3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -23,8 +23,6 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.hive.test.HiveTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with HiveTestUtils { - private val ctx = hiveContext - // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 830591ed54b8..58b0f7f1ba89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -28,9 +28,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = hiveContext - import ctx.implicits._ - import ctx._ + import testImplicits._ override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName @@ -42,7 +40,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext + ctx.sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write.parquet(partitionDir.toString) @@ -52,7 +50,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -70,7 +68,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(ctx.read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -98,7 +96,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) + checkAnswer(ctx.read.format("parquet").load(path), df) } } @@ -108,7 +106,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) + ctx.range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 5c8beea3a975..f6e47814d8cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -23,8 +23,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - private val ctx = hiveContext - import ctx._ override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName @@ -36,7 +34,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext + ctx.sparkContext .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") .saveAsTextFile(partitionDir.toString) } @@ -45,7 +43,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index dee39d8f7888..8079a82670b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -33,9 +33,7 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with HiveTestUtils { - private val ctx = hiveContext - import ctx.implicits._ - import ctx.sql + import testImplicits._ val dataSourceName: String @@ -90,7 +88,7 @@ abstract class HadoopFsRelationTest extends QueryTest with HiveTestUtils { df.registerTempTable("t") withTempTable("t") { checkAnswer( - sql( + ctx.sql( """SELECT l.a, r.b, l.p1, r.p2 |FROM t l JOIN t r |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 From c4a22bcb8d26362f741c9b24f31552694b3d4f64 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 20:39:14 -0700 Subject: [PATCH 13/39] Fix compile after resolving merge conflicts --- .../apache/spark/sql/DataFrameStatSuite.scala | 10 ++++------ .../apache/spark/sql/JsonFunctionsSuite.scala | 6 +++--- .../spark/sql/execution/PlannerSuite.scala | 4 ++++ .../execution/RowFormatConvertersSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 ++------ .../execution/datasources/json/JsonSuite.scala | 17 +++++++++-------- .../sql/execution/metric/SQLMetricsSuite.scala | 8 ++++---- .../sql/execution/ui/SQLListenerSuite.scala | 14 +++++++------- .../apache/spark/sql/sources/InsertSuite.scala | 4 ++-- .../sql/sources/PartitionedWriteSuite.scala | 14 +++++++------- 10 files changed, 43 insertions(+), 44 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d1d1a14355a8..5385f1e0dc64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SQLTestUtils @@ -31,7 +29,7 @@ class DataFrameStatSuite extends QueryTest with SQLTestUtils { test("sample with replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -40,7 +38,7 @@ class DataFrameStatSuite extends QueryTest with SQLTestUtils { test("sample without replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -49,7 +47,7 @@ class DataFrameStatSuite extends QueryTest with SQLTestUtils { test("randomSplit") { val n = 600 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -166,7 +164,7 @@ class DataFrameStatSuite extends QueryTest with SQLTestUtils { } test("Frequent Items 2") { - val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. This test should also fail // if anything like SPARK-9614 is observed once again diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 71c26a6f8d36..dab4b9c65295 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class JsonFunctionsSuite extends QueryTest { +import org.apache.spark.sql.test.SQLTestUtils - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class JsonFunctionsSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("function get_json_object") { val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fdc0e27c007c..31fde3a03173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -37,6 +37,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -51,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val 
planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index b71a66b22bd8..b9773a1e7bfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -92,7 +92,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(TestSQLContext) + SparkPlan.currentContext.set(ctx) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index eb4b27df2f6a..11a242bf9402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -39,7 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers - with SharedSQLContext { + with SQLTestUtils { import UnsafeFixedWidthAggregationMap._ @@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("test external sorting with an empty map") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, @@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("test external sorting with empty records") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext // Memory consumption in the beginning of the task. 
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index fe1c9dee01e2..8669db2e9ae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1056,7 +1056,7 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(çtx) + None, None)(ctx) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1081,14 +1081,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1154,11 +1154,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { "abd") ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) - checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + checkAnswer(ctx.sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(ctx.sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(ctx.sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 953284c98b20..4cc8638e232a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -25,14 +25,14 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class SQLMetricsSuite extends SparkFunSuite { +class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -43,7 +43,7 @@ class SQLMetricsSuite extends SparkFunSuite { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. 
- val l = TestSQLContext.sparkContext.accumulator(0L) + val l = ctx.sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 41dd1896c15d..71fa58f2b95c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SQLTestUtils -class SQLListenerSuite extends SparkFunSuite { +class SQLListenerSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ private def createTestDataFrame: DataFrame = { - import TestSQLContext.implicits._ Seq( (1, 1), (2, 2) @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("basic") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2a551458d948..cf9564c72542 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -193,7 +193,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("Caching") { // write something to the jsonTable - sql( + caseInsensitiveContext.sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) @@ -230,7 +230,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt """.stripMargin) // jsonTable should be recached. - assertCached(sql("SELECT * FROM jsonTable")) + assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable")) // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation // // The cached data is the new data. 
// checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c86ddd7c83e5..47f7d1605002 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -19,21 +19,21 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class PartitionedWriteSuite extends QueryTest { - import TestSQLContext.implicits._ +class PartitionedWriteSuite extends QueryTest with SQLTestUtils { + import testImplicits._ test("write many partitions") { val path = Utils.createTempDir() path.delete() - val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + val df = ctx.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest { val path = Utils.createTempDir() path.delete() - val base = TestSQLContext.range(100) + val base = ctx.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) From f5619f8201f8dff2b335d42e9fa1b303735de4a8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 13:08:42 -0700 Subject: [PATCH 14/39] Fix test compile after resolving merge conflicts --- .../ParquetProtobufCompatibilitySuite.scala | 4 +- .../sql/execution/joins/InnerJoinSuite.scala | 113 +++++++++--------- .../sql/execution/joins/OuterJoinSuite.scala | 4 +- .../spark/sql/sources/DataSourceTest.scala | 7 +- .../sql/hive/test/HiveDataSourceTest.scala | 28 +++++ .../test/HiveParquetCompatibilityTest.scala | 2 +- .../spark/sql/hive/test/HiveParquetTest.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 29 ++--- .../spark/sql/hive/HiveParquetSuite.scala | 7 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 20 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 18 +-- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 13 files changed, 134 insertions(+), 108 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 981334cf771c..7cab6a415ac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import 
org.apache.spark.sql.{DataFrame, Row} class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest { - override def sqlContext: SQLContext = TestSQLContext private def readParquetProtobufFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index ddff7cebcc17..f0c96eab0f5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution._ +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { +class InnerJoinSuite extends SparkPlanTest { private def testInnerJoin( testName: String, @@ -107,23 +106,25 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { } { - val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, "A"), - Row(2, "B"), - Row(3, "C"), - Row(4, "D"), - Row(5, "E"), - Row(6, "F"), - Row(null, "G") - )), new StructType().add("N", IntegerType).add("L", StringType)) - - val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"), - Row(4, "d"), - Row(null, "e") - )), new StructType().add("n", IntegerType).add("l", StringType)) + lazy val upperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + lazy val lowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) testInnerJoin( "inner join, one match per row", @@ -139,42 +140,44 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { ) } - private val testData2 = Seq( - (1, 1), - (1, 2), - (2, 1), - (2, 2), - (3, 1), - (3, 2) - ).toDF("a", "b") - { - val left = testData2.where("a = 1") - val right = testData2.where("a = 1") - testInnerJoin( - "inner join, multiple matches", - left, - right, - (left.col("a") === right.col("a")).expr, - Seq( - (1, 1, 1, 1), - (1, 1, 1, 2), - (1, 2, 1, 1), - (1, 2, 1, 2) + lazy val testData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + { + lazy val left = testData2.where("a = 1") + lazy val right = testData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) ) - ) 
- } + } - { - val left = testData2.where("a = 1") - val right = testData2.where("a = 2") - testInnerJoin( - "inner join, no matches", - left, - right, - (left.col("a") === right.col("a")).expr, - Seq.empty - ) + { + lazy val left = testData2.where("a = 1") + lazy val right = testData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c0b9078f2fa6..9ae9b8020a3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} @@ -53,7 +53,7 @@ class OuterJoinSuite extends SparkPlanTest { Row(null, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private val condition = { + private lazy val condition = { And((left.col("a") === right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 2da4c06f9a9c..0b56f9a809f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -23,11 +23,14 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils -abstract class DataSourceTest extends QueryTest with BeforeAndAfter with SQLTestUtils { +private[sql] abstract class DataSourceTest extends AbstractDataSourceTest with SQLTestUtils + +private[sql] abstract class AbstractDataSourceTest extends QueryTest with BeforeAndAfter { + protected def _sqlContext: SQLContext // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(sqlContext.sparkContext) + val ctx = new SQLContext(_sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala new file mode 100644 index 000000000000..e011d86c69ae --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.sql.sources.AbstractDataSourceTest + + +/** + * An equivalent of [[org.apache.spark.sql.sources.DataSourceTest]], but for hive tests. + */ +private[hive] abstract class HiveDataSourceTest + extends AbstractDataSourceTest + with HiveTestUtils diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala index afcd0c97d11c..468e712f64a1 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.test -import org.apache.spark.sql.parquet.AbstractParquetCompatibilityTest +import org.apache.spark.sql.execution.datasources.parquet.AbstractParquetCompatibilityTest /** * Helper class for testing Parquet compatibility in hive. diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala index b909a9c88829..9f053b4b855a 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.test -import org.apache.spark.sql.parquet.AbstractParquetTest +import org.apache.spark.sql.execution.datasources.parquet.AbstractParquetTest /** * Helper trait for Parquet tests analogous to [[org.apache.spark.sql.parquet.ParquetTest]]. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 015c8ce42638..bba54329c5b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -22,8 +22,7 @@ import java.io.File import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} -import org.apache.spark.sql.hive.test.HiveTestUtils -import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.hive.test.{HiveDataSourceTest, HiveTestUtils} import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType @@ -51,8 +50,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with HiveTestUtils with Lo } } -class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTestUtils { - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) +class DataSourceWithHiveMetastoreCatalogSuite extends HiveDataSourceTest { + import testImplicits._ + + private lazy val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) Seq( "parquet" -> ( @@ -75,7 +76,7 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe .format(provider) .saveAsTable("t") - val hiveTable = catalog.client.getTable("default", "t") + val hiveTable = ctx.catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) assert(hiveTable.serde === Some(serde)) @@ -87,8 +88,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("int", "string")) - checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + checkAnswer(ctx.table("t"), testDF) + assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) } } @@ -104,7 +105,7 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe .option("path", path.toString) .saveAsTable("t") - val hiveTable = catalog.client.getTable("default", "t") + val hiveTable = ctx.catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) assert(hiveTable.serde === Some(serde)) @@ -116,8 +117,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("int", "string")) - checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + checkAnswer(ctx.table("t"), testDF) + assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) } } } @@ -127,13 +128,13 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe withTable("t") { val path = dir.getCanonicalPath - sql( + ctx.sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = catalog.client.getTable("default", "t") + val hiveTable = ctx.catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) 
assert(hiveTable.serde === Some(serde)) @@ -146,8 +147,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with HiveTe assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("int", "string")) - checkAnswer(table("t"), Row(1, "val_1")) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(ctx.table("t"), Row(1, "val_1")) + assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 5b0c961aa986..f763d007a10a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,13 +17,8 @@ package org.apache.spark.sql.hive -<<<<<<< HEAD -import org.apache.spark.sql.hive.test.HiveParquetTest -======= -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.execution.datasources.parquet.ParquetTest ->>>>>>> 3c9802d9400bea802984456683b2736a450ee17e import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.HiveParquetTest case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cbfc861d670b..a0a929adb4b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -475,7 +475,7 @@ class MetastoreDataSourcesSuite // Drop table will also delete the data. ctx.sql("DROP TABLE savedJsonTable") intercept[IOException] { - ctx.read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + ctx.read.json(ctx.catalog.hiveDefaultTableFilePath("savedJsonTable")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index aeaf3df1f12c..0a79d33b1011 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -69,8 +69,8 @@ abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { emptyDF.registerTempTable("emptyTable") // Register UDAFs - ctx.udaf.register("mydoublesum", new MyDoubleSum) - ctx.udaf.register("mydoubleavg", new MyDoubleAvg) + ctx.udf.register("mydoublesum", new MyDoubleSum) + ctx.udf.register("mydoubleavg", new MyDoubleAvg) } override def afterAll(): Unit = { @@ -139,7 +139,7 @@ abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { test("null literal") { checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | AVG(null), @@ -409,7 +409,7 @@ abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) checkAnswer( - sqlContext.sql( + ctx.sql( """ |SELECT | count(value1), @@ -566,26 +566,26 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + originalUnsafeEnabled = ctx.conf.unsafeEnabled + 
ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") super.beforeAll() } override def afterAll(): Unit = { super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + ctx.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") } override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( + ctx.setConf( "spark.sql.TungstenAggregate.testFallbackStartsAt", fallbackStartsAt.toString) // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. - val newActual = DataFrame(sqlContext, actual.logicalPlan) + val newActual = DataFrame(ctx, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 95025afa4896..14f25db6c9fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -755,7 +755,7 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { val data = (1 to 5).map { i => (i, i) } data.toDF("key", "value").registerTempTable("test") checkAnswer( - sql("""FROM + ctx.sql("""FROM |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t |SELECT thing1 + 1 """.stripMargin), (2 to 6).map(i => Row(i))) @@ -1115,22 +1115,22 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { } test("Convert hive interval term into Literal of CalendarIntervalType") { - checkAnswer(sql("select interval '10-9' year to month"), + checkAnswer(ctx.sql("select interval '10-9' year to month"), Row(CalendarInterval.fromString("interval 10 years 9 months"))) - checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + checkAnswer(ctx.sql("select interval '20 15:40:32.99899999' day to second"), Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + "32 seconds 99 milliseconds 899 microseconds"))) - checkAnswer(sql("select interval '30' year"), + checkAnswer(ctx.sql("select interval '30' year"), Row(CalendarInterval.fromString("interval 30 years"))) - checkAnswer(sql("select interval '25' month"), + checkAnswer(ctx.sql("select interval '25' month"), Row(CalendarInterval.fromString("interval 25 months"))) - checkAnswer(sql("select interval '-100' day"), + checkAnswer(ctx.sql("select interval '-100' day"), Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) - checkAnswer(sql("select interval '40' hour"), + checkAnswer(ctx.sql("select interval '40' hour"), Row(CalendarInterval.fromString("interval 1 days 16 hours"))) - checkAnswer(sql("select interval '80' minute"), + checkAnswer(ctx.sql("select interval '80' minute"), Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) - checkAnswer(sql("select interval '299.889987299' second"), + checkAnswer(ctx.sql("select interval '299.889987299' second"), Row(CalendarInterval.fromString( "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index be825f1e1b86..fcbeec366d12 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -54,8 +54,6 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName - import sqlContext._ - test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -64,7 +62,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext + ctx.sparkContext .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") .saveAsTextFile(partitionDir.toString) } @@ -73,7 +71,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } From c51b3d87e96c6c6652586220701391ca2cc3a1b1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 14:07:19 -0700 Subject: [PATCH 15/39] Fix OuterJoinSuite This suite was using the SQLContext extensively in the constructor of the test suite. The fact that we don't have the singleton anymore means this is no longer possible. This commit refactors the suite to never reference a SQLContext outside of a test body. 
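
For reference, a condensed sketch of the new shape of the helper, simplified from the diff
below (only the ShuffledHashOuterJoin variant is shown; the broadcast and sort-merge variants
follow the same pattern, and checkAnswer2/withSQLConf/sqlContext come from SparkPlanTest).
The inputs and the join condition become by-name parameters, so nothing that needs a
SQLContext is evaluated until a registered test body actually runs:

    import org.apache.spark.sql.{DataFrame, Row, SQLConf}
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
    import org.apache.spark.sql.catalyst.plans._
    import org.apache.spark.sql.catalyst.plans.logical.Join
    import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan}

    // Inside OuterJoinSuite (which extends SparkPlanTest):
    private def testOuterJoin(
        testName: String,
        leftRows: => DataFrame,          // by-name: evaluated only inside a test body
        rightRows: => DataFrame,
        joinType: JoinType,
        condition: => Expression,
        expectedAnswer: Seq[Product]): Unit = {

      // Join keys are re-extracted lazily, once per registered test.
      def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
        val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
        ExtractEquiJoinKeys.unapply(join)
      }

      test(s"$testName using ShuffledHashOuterJoin") {
        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              EnsureRequirements(sqlContext).apply(
                ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
              expectedAnswer.map(Row.fromTuple),
              sortAnswers = true)
          }
        }
      }
    }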
--- .../sql/execution/joins/OuterJoinSuite.scala | 56 ++++++++----------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 9ae9b8020a3b..7d2d3449982c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} class OuterJoinSuite extends SparkPlanTest { @@ -58,17 +58,23 @@ class OuterJoinSuite extends SparkPlanTest { LessThan(left.col("b").expr, right.col("d").expr)) } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testOuterJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, + leftRows: => DataFrame, + rightRows: => DataFrame, joinType: JoinType, - condition: Expression, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using ShuffledHashOuterJoin") { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -76,19 +82,23 @@ class OuterJoinSuite extends SparkPlanTest { expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -96,25 +106,7 @@ class OuterJoinSuite extends SparkPlanTest { 
expectedAnswer.map(Row.fromTuple), sortAnswers = false) } - } } - } - - test(s"$testName using BroadcastNestedLoopJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastNestedLoopJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) } } } From 1400770941263d586f94ef59f40199a789579b4a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 14:16:48 -0700 Subject: [PATCH 16/39] Fix style --- .../sql/execution/datasources/json/JsonSuite.scala | 12 ++++++++---- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 ++++++--- .../org/apache/spark/sql/sources/InsertSuite.scala | 7 ++++--- .../apache/spark/sql/sources/TableScanSuite.scala | 7 ++++--- .../apache/spark/sql/hive/QueryPartitionSuite.scala | 1 - .../spark/sql/hive/execution/HiveExplainSuite.scala | 4 +++- .../spark/sql/hive/execution/HiveQuerySuite.scala | 9 ++++++--- .../spark/sql/hive/execution/HiveUDFSuite.scala | 6 ++++-- .../spark/sql/hive/execution/PruningSuite.scala | 6 +++--- .../spark/sql/hive/execution/SQLQuerySuite.scala | 10 ++++++---- 10 files changed, 44 insertions(+), 27 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8669db2e9ae6..5385cee0c461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -304,7 +304,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - ctx.sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + ctx.sql( + "select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) @@ -355,7 +356,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { // Access elements of an array field of a struct. 
checkAnswer( - ctx.sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + ctx.sql( + "select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), Row(5, null) ) } @@ -590,7 +592,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - _sqlContext.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + _sqlContext.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) @@ -1021,7 +1024,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { // Access an array field of a struct. checkAnswer( - ctx.sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), + ctx.sql( + "select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), Row(Seq(4, 5, 6), Seq("str1", "str2")) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d1120b2c910d..e249f9b89f01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -185,11 +185,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { } test("SELECT * WHERE (quoted strings)") { - assert(ctx.sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) + assert( + ctx.sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { - val names = ctx.sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = + ctx.sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) @@ -197,7 +199,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { } test("SELECT first field when fetchSize is two") { - val names = ctx.sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = + ctx.sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index cf9564c72542..d0c533cc4f59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -219,9 +219,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { caseInsensitiveContext.sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - caseInsensitiveContext.sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(caseInsensitiveContext.sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + 
checkAnswer(caseInsensitiveContext.sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5c617b624d82..7b15d3fa3d74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -301,9 +301,10 @@ class TableScanSuite extends DataSourceTest { caseInsensitiveContext.sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - caseInsensitiveContext.sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(caseInsensitiveContext.sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(caseInsensitiveContext.sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 5b9bbddd0dc0..910ec4654cf8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 860499ecb846..1c9941949a9e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -47,7 +47,9 @@ class HiveExplainSuite extends QueryTest with HiveTestUtils { "Limit", "src") - checkExistence(ctx.sql("explain extended create table temp__b as select * from src limit 2"), true, + checkExistence(ctx.sql( + "explain extended create table temp__b as select * from src limit 2"), + true, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index d83e8b7a2275..f3a902a87e15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -726,7 +726,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .size == 5) assert( - ctx.sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + ctx.sql( + "select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") .collect() .size == 5) } @@ -773,7 +774,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("DESCRIBE commands") { - ctx.sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + ctx.sql( + s"CREATE TABLE 
test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") ctx.sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') @@ -935,7 +937,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { ignore("Dynamic partition folder layout") { ctx.sql("DROP TABLE IF EXISTS dynamic_part_table") - ctx.sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + ctx.sql( + "CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") val data = Map( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 9592b755ca0a..c4384dbcc7d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -176,7 +176,8 @@ class HiveUDFSuite extends QueryTest with HiveTestUtils { val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + ctx.sql( + s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") val errMsg = intercept[AnalysisException] { ctx.sql("SELECT testUDFToListString(s) FROM inputTable") } @@ -270,7 +271,8 @@ class HiveUDFSuite extends QueryTest with HiveTestUtils { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + ctx.sql( + s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( ctx.sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index ceb190fe7e91..34ea48a18310 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -28,9 +28,9 @@ class PruningSuite extends HiveComparisonTest { protected override def beforeAll(): Unit = { super.beforeAll() ctx.cacheTables = false - // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset - // the environment to ensure all referenced tables in this suites are not cached in-memory. - // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need + // to reset the environment to ensure all referenced tables in this suites are not cached + // in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. 
ctx.reset() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 14f25db6c9fe..3cdb4d0218b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -312,15 +312,16 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { checkRelation("ctas1", true) ctx.sql("DROP TABLE ctas1") - ctx.sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql( + "CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) ctx.sql("DROP TABLE ctas1") ctx.sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) ctx.sql("DROP TABLE ctas1") - - ctx.sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + ctx.sql( + "CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) ctx.sql("DROP TABLE ctas1") } finally { @@ -636,7 +637,8 @@ class SQLQuerySuite extends QueryTest with HiveTestUtils { } test("SPARK-4296 Grouping field with Hive UDF as sub expression") { - val rdd = ctx.sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + val rdd = ctx.sparkContext.makeRDD( + """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) ctx.read.json(rdd).registerTempTable("data") checkAnswer( ctx.sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), From c92a3b04883cfca1c3e0b71048aba9471c9fde26 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 15:04:03 -0700 Subject: [PATCH 17/39] Fix MiMa --- project/MimaExcludes.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90261ca3d61a..a540eddac1d8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,16 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") ) case v if v.startsWith("1.4") => From 4debedfe75aba02e396ff6cc66592645075b869d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 16:43:29 -0700 Subject: [PATCH 18/39] Clean up inheritance in test util traits Per discussion with @JoshRosen, SharedSQLContext now extends SQLTestUtils instead of the other way round. This actually allows us to get rid of a layer of abstraction needed for hive tests. 
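
For context, here is a minimal, hypothetical sketch of the layering this patch describes (it is not the real Spark source; the bodies, constructor arguments, and setup/teardown details are simplified for illustration). The point is that suites now mix in a single SharedSQLContext trait, which itself extends SQLTestUtils and owns the lazily created singleton context, instead of SQLTestUtils wrapping a separately provided context trait.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.scalatest.{BeforeAndAfterAll, Suite}

    // Generic test helpers: they only require that *some* SQLContext be supplied.
    trait SQLTestUtils extends BeforeAndAfterAll { self: Suite =>
      protected def _sqlContext: SQLContext
    }

    // The shared context now extends the helper trait (rather than the reverse),
    // so a suite needs just one mixin, e.g.
    //   class FooSuite extends QueryTest with SharedSQLContext
    trait SharedSQLContext extends SQLTestUtils { self: Suite =>
      private var _ctx: SQLContext = _

      protected def ctx: SQLContext = _ctx
      protected override def _sqlContext: SQLContext = _ctx

      protected override def beforeAll(): Unit = {
        super.beforeAll()
        if (_ctx == null) {
          // Simplified: the real trait reuses a shared SparkContext across suites.
          _ctx = new SQLContext(
            new SparkContext("local[2]", "shared-sql-context-sketch", new SparkConf()))
        }
      }

      protected override def afterAll(): Unit = {
        try {
          if (_ctx != null) {
            _ctx.sparkContext.stop()
            _ctx = null
          }
        } finally {
          super.afterAll()
        }
      }
    }

The Hive side mirrors this with SharedHiveContext extending the same utilities, which is why the intermediate Hive*Test helper traits are deleted in the diff below.
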
--- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../spark/sql/ColumnExpressionSuite.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- .../spark/sql/DataFrameImplicitsSuite.scala | 4 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 4 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 4 +- .../apache/spark/sql/DataFrameStatSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../spark/sql/DataFrameTungstenSuite.scala | 4 +- .../apache/spark/sql/DateFunctionsSuite.scala | 4 +- .../org/apache/spark/sql/JoinSuite.scala | 4 +- .../apache/spark/sql/JsonFunctionsSuite.scala | 4 +- .../apache/spark/sql/ListTablesSuite.scala | 4 +- .../spark/sql/MathExpressionsSuite.scala | 4 +- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../org/apache/spark/sql/SQLConfSuite.scala | 4 +- .../apache/spark/sql/SQLContextSuite.scala | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../sql/ScalaReflectionRelationSuite.scala | 4 +- .../spark/sql/StringFunctionsSuite.scala | 4 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 4 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 4 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../execution/RowFormatConvertersSuite.scala | 3 +- .../spark/sql/execution/SortSuite.scala | 3 +- .../spark/sql/execution/SparkPlanTest.scala | 8 +--- .../sql/execution/TungstenSortSuite.scala | 3 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 4 +- .../UnsafeKVExternalSorterSuite.scala | 4 +- .../datasources/json/JsonSuite.scala | 6 +-- .../ParquetAvroCompatibilitySuite.scala | 3 +- .../parquet/ParquetCompatibilityTest.scala | 14 +----- .../parquet/ParquetFilterSuite.scala | 5 +- .../datasources/parquet/ParquetIOSuite.scala | 3 +- .../ParquetPartitionDiscoverySuite.scala | 5 +- .../ParquetProtobufCompatibilitySuite.scala | 3 +- .../parquet/ParquetQuerySuite.scala | 5 +- .../parquet/ParquetSchemaSuite.scala | 4 +- .../datasources/parquet/ParquetTest.scala | 10 +--- .../ParquetThriftCompatibilitySuite.scala | 3 +- .../sql/execution/debug/DebuggingSuite.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 4 +- .../sql/execution/joins/InnerJoinSuite.scala | 3 +- .../sql/execution/joins/OuterJoinSuite.scala | 3 +- .../sql/execution/joins/SemiJoinSuite.scala | 7 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 4 +- .../sources/CreateTableAsSelectSuite.scala | 6 ++- .../sql/sources/DDLSourceLoadSuite.scala | 3 +- .../spark/sql/sources/DDLTestSuite.scala | 5 +- .../spark/sql/sources/DataSourceTest.scala | 7 +-- .../spark/sql/sources/FilteredScanSuite.scala | 5 +- .../spark/sql/sources/InsertSuite.scala | 5 +- .../sql/sources/PartitionedWriteSuite.scala | 4 +- .../spark/sql/sources/PrunedScanSuite.scala | 5 +- .../spark/sql/sources/SaveLoadSuite.scala | 5 +- .../spark/sql/sources/TableScanSuite.scala | 5 +- .../apache/spark/sql/test/SQLTestUtils.scala | 17 +++---- .../spark/sql/test/SharedSQLContext.scala | 45 ++++-------------- .../sql/hive/test/HiveDataSourceTest.scala | 28 ----------- .../test/HiveParquetCompatibilityTest.scala | 28 ----------- .../spark/sql/hive/test/HiveParquetTest.scala | 26 ----------- .../sql/hive/test/HiveSparkPlanTest.scala | 28 ----------- 
.../spark/sql/hive/test/HiveTestUtils.scala | 30 ------------ .../sql/hive/test/SharedHiveContext.scala | 46 ++++--------------- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../spark/sql/hive/ErrorPositionSuite.scala | 4 +- .../hive/HiveDataFrameAnalyticsSuite.scala | 4 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 4 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 4 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 7 +-- .../spark/sql/hive/HiveParquetSuite.scala | 5 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 4 +- .../spark/sql/hive/ListTablesSuite.scala | 4 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 4 +- .../hive/ParquetHiveCompatibilitySuite.scala | 4 +- .../spark/sql/hive/QueryPartitionSuite.scala | 4 +- .../spark/sql/hive/SerializationSuite.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 4 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 4 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveExplainSuite.scala | 4 +- .../HiveOperatorQueryableSuite.scala | 4 +- .../sql/hive/execution/HivePlanTest.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- .../execution/ScriptTransformationSuite.scala | 6 +-- .../hive/orc/OrcPartitionDiscoverySuite.scala | 4 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 4 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 4 +- .../CommitFailureTestRelationSuite.scala | 4 +- .../sql/sources/hadoopFsRelationSuites.scala | 4 +- 102 files changed, 231 insertions(+), 427 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index ce2896b2d51a..0f20c775c120 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SQLTestUtils { +class CachedTableSuite extends QueryTest with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2ecd6ee445de..df1e784dce0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,9 +22,9 @@ import org.scalatest.Matchers._ import 
org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 29e0b0805c20..72cf7aab0b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest with SQLTestUtils { +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("groupBy") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 096f693c06c1..7d6ef5a1f085 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest with SQLTestUtils { +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("array with column name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index b52a222c0a57..e5d7d63441a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameImplicitsSuite extends QueryTest with SQLTestUtils { +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("RDD of tuples") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 49543ea5b481..9d7cb2de67b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest with SQLTestUtils { +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("join - join using") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index eb54287558ef..cdaa14ac8078 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest with SQLTestUtils { +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ def createDF(): DataFrame = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 5385f1e0dc64..28bdd6f83b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import java.util.Random import org.apache.spark.sql.functions.col -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest with SQLTestUtils { +class DataFrameStatSuite extends QueryTest with SharedSQLContext { import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e81e2d2f4327..2e26eccc9ebe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ 
-28,9 +28,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { +class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("analysis error should be eagerly reported") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index b2aacc04755b..77907e91363e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. */ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("test simple types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index bc0f9ca33abd..8cfa9189ef07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,10 +22,10 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest with SQLTestUtils { +class DateFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("function current_date") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 11f8f7a10a75..1a963f33ed60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with BeforeAndAfterEach with SQLTestUtils { +class JoinSuite extends QueryTest with BeforeAndAfterEach with SharedSQLContext { import testImplicits._ setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index dab4b9c65295..045fea82e4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class JsonFunctionsSuite extends QueryTest with SQLTestUtils { +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("function get_json_object") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index d73f2f4e02dd..2a80cab0bc51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter with SQLTestUtils { +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 617ce83d7ff1..455bf306d7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest with SQLTestUtils { +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index b22944e37ef4..795d4e983f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite with SQLTestUtils { +class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bcb6bd887bcc..fa2aabb4f2fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import 
org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest with SQLTestUtils { +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 23fe33d61f80..056570a0c978 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with SQLTestUtils { +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { SQLContext.setLastInstantiatedContext(ctx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a7a154ad07fc..312f6f008080 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -25,14 +25,14 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with SQLTestUtils { +class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index e8b8224343b0..c1ae8d04fab1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -72,7 +72,7 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite with SQLTestUtils { +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("query case class RDD") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 0f9385f895bf..cc95eede005d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class 
StringFunctionsSuite extends QueryTest with SQLTestUtils { +class StringFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("string concat") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 77ecdcda2441..46056c16533f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { +class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index d8d684e56417..de637628debc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,7 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -67,7 +67,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest with SQLTestUtils { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val pointsRDD = Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index bba4e5721684..261a1878ac7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest with SQLTestUtils { +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 483d3e5b07be..821160c0abd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfter import 
org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 79e903c2bbd4..8998f5111124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext -class ExchangeSuite extends SparkPlanTest { +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 31fde3a03173..1f36c28bd28b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index b9773a1e7bfb..ef6ad59b71fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,10 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String -class RowFormatConvertersSuite extends SparkPlanTest { +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a2c10fdaf6cd..8fa77b0fcb7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext -class SortSuite extends SparkPlanTest { +class SortSuite extends SparkPlanTest with SharedSQLContext { // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 29cb920e6c05..3a87f374d94b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -25,18 +25,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.SQLTestUtils /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -private[sql] abstract class SparkPlanTest extends AbstractSparkPlanTest with SQLTestUtils - -/** - * Helper class for testing individual physical operators with a pluggable [[SQLContext]]. - */ -private[sql] abstract class AbstractSparkPlanTest extends SparkFunSuite { +private[sql] abstract class SparkPlanTest extends SparkFunSuite { protected def _sqlContext: SQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 7f6651052eba..20eef186a46f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -22,12 +22,13 @@ import scala.util.Random import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * A test suite that generates randomized data to test the [[TungstenSort]] operator. 
*/ -class TungstenSortSuite extends SparkPlanTest { +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 11a242bf9402..d1f0b2b1fc52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -39,7 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers - with SQLTestUtils { + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 5620d202197c..d3be568a8758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,14 +23,14 @@ import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. 
*/ -class UnsafeKVExternalSorterSuite extends SparkFunSuite with SQLTestUtils { +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 5385cee0c461..103c3d510177 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,15 +24,15 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ test("Type promotion") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bffa78cdd5b0..69ec7a11841d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -27,8 +27,9 @@ import org.apache.parquet.avro.AvroParquetWriter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.test.SharedSQLContext -class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ override protected def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 11bed317e1d0..a6c04c8df8dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -21,7 +21,6 @@ import java.io.File import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType @@ -32,18 +31,7 @@ import org.apache.spark.util.Utils /** * Helper class for testing Parquet compatibility. 
*/ -private[sql] abstract class ParquetCompatibilityTest - extends AbstractParquetCompatibilityTest - with ParquetTest - -/** - * Abstract helper class for testing Parquet compatibility with a pluggable - * [[org.apache.spark.sql.SQLContext]]. - */ -private[sql] abstract class AbstractParquetCompatibilityTest - extends QueryTest - with AbstractParquetTest - with BeforeAndAfterAll { +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { protected var parquetStore: File = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index e4b9633b7201..5b4e568bb983 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -39,7 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 8b3bbe6e199b..a5279bb6a08d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -62,7 +63,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. 
*/ -class ParquetIOSuite extends QueryTest with ParquetTest { +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { import testImplicits._ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index a2f187dbccbc..be3afaa87abc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -26,11 +26,12 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String // The data where the partitioning key exists only in the directory structure. @@ -39,7 +40,7 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { import PartitioningUtils._ import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 7cab6a415ac4..b290429c2a02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext -class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { private def readParquetProtobufFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2287d2a6d467..11c5818657a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -21,14 +21,15 @@ import java.io.File import org.apache.hadoop.fs.Path -import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various 
Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index a50eb66e88cd..9dcbc1a047be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,11 +22,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 61e2742bb0f3..5dbc7d1630f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -22,8 +22,7 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.{AbstractSQLTestUtils, SQLTestUtils} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** @@ -33,12 +32,7 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends AbstractParquetTest with SQLTestUtils - -/** - * Abstract helper trait for Parquet tests with a pluggable [[SQLContext]]. 
- */ -private[sql] trait AbstractParquetTest extends SparkFunSuite with AbstractSQLTestUtils { +private[sql] trait ParquetTest extends SQLTestUtils { protected def _sqlContext: SQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index ed6cf82e6b7b..b789c5a106e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext -class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ private val parquetFilePath = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 0c7a29a0469c..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class DebuggingSuite extends SparkFunSuite with SQLTestUtils { +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { test("DataFrame.debug()") { testData.debug() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index aa62e66b4a2d..ae3bc244686e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite with SQLTestUtils { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index f0c96eab0f5a..a7ef11c5b92b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import 
org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class InnerJoinSuite extends SparkPlanTest { +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { private def testInnerJoin( testName: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 7d2d3449982c..a1a617d7b739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} -class OuterJoinSuite extends SparkPlanTest { +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { private lazy val left = ctx.createDataFrame( ctx.sparkContext.parallelize(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index d9f83ad11e42..cfe6f57d56cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{SQLConf, DataFrame, Row} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} -import org.apache.spark.sql.{SQLConf, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} -class SemiJoinSuite extends SparkPlanTest { +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { private lazy val left = ctx.createDataFrame( ctx.sparkContext.parallelize(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 78c446b4631a..80006bf077fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -28,11 +28,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("LongSQLMetric should not box Long") { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 71fa58f2b95c..80d1e8895694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class SQLListenerSuite extends SparkFunSuite with SQLTestUtils { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ private def createTestDataFrame: DataFrame = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e249f9b89f01..5ac409d53649 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,11 +25,11 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { import testImplicits._ val url = "jdbc:h2:mem:testdb0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index f2431d892e68..78f521d380f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,11 +24,11 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, SaveMode} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SQLTestUtils { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6ec4f0a7e459..82e9ce954d4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { + +class CreateTableAsSelectSuite extends DataSourceTest 
with SharedSQLContext with BeforeAndAfter { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 392da0b0826b..853707c036c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} // please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { test("data sources with the same name") { intercept[RuntimeException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 84855ce45e91..6d9a61175c9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.sources +import org.scalatest.BeforeAndAfter + import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,7 +71,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { before { caseInsensitiveContext.sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 0b56f9a809f7..d74d29fb0beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,15 +17,10 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils - -private[sql] abstract class DataSourceTest extends AbstractDataSourceTest with SQLTestUtils -private[sql] abstract class AbstractDataSourceTest extends QueryTest with BeforeAndAfter { +private[sql] abstract class DataSourceTest extends QueryTest { protected def _sqlContext: SQLContext // We want to test some edge cases. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 2950c058e297..613467e59efb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.scalatest.BeforeAndAfter + import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -96,7 +99,7 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { before { caseInsensitiveContext.sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index d0c533cc4f59..c880d5f7df6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { +class InsertSuite extends DataSourceTest with SharedSQLContext { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 47f7d1605002..79b6e9b45c00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class PartitionedWriteSuite extends QueryTest with SQLTestUtils { +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("write many partitions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 0d5183444af7..6130780e8044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.scalatest.BeforeAndAfter + import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,7 +54,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { 
before { caseInsensitiveContext.sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 5be439fcdca4..ba41e89738f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,11 +19,14 @@ package org.apache.spark.sql.sources import java.io.File +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest { +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 7b15d3fa3d74..5bde48740c90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -20,8 +20,11 @@ package org.apache.spark.sql.sources import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import org.scalatest.BeforeAndAfter + import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource @@ -95,7 +98,7 @@ case class AllDataTypesScan( } } -class TableScanSuite extends DataSourceTest { +class TableScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index dbe75654e56f..8c1d9c180e7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -31,22 +31,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils -/** - * Helper trait that should be extended by all SQL test suites involving a - * [[org.apache.spark.sql.SQLContext]]. - */ -private[sql] trait SQLTestUtils extends AbstractSQLTestUtils with SharedSQLContext { - protected final override def _sqlContext = sqlContext -} - /** * Helper trait that should be extended by all SQL test suites. * - * This base trait allows subclasses to plugin a custom [[SQLContext]]. It comes with test - * data prepared in advance as well as all implicit conversions used extensively by dataframes. + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. 
*/ -private[sql] trait AbstractSQLTestUtils +private[sql] trait SQLTestUtils extends SparkFunSuite with BeforeAndAfterAll with SQLTestData { self => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e3eb0114fb1b..c1e5939b83a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.test -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.SQLContext -import org.apache.spark.SparkFunSuite /** * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. */ -private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll { +private[sql] trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. @@ -34,6 +33,13 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll */ private var _ctx: TestSQLContext = null + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx + protected override def _sqlContext: SQLContext = _ctx + /** * Initialize the [[TestSQLContext]]. * This is a no-op if the user explicitly switched to a custom context before this is called. @@ -56,37 +62,4 @@ private[sql] trait SharedSQLContext extends SparkFunSuite with BeforeAndAfterAll super.afterAll() } - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected def ctx: TestSQLContext = _ctx - protected def sqlContext: TestSQLContext = _ctx - - /** - * Switch to a custom [[TestSQLContext]]. - * - * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one - * to be created. This is necessary because only one [[org.apache.spark.SparkContext]] - * is allowed per JVM. - */ - protected def switchSQLContext(newContext: () => TestSQLContext): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - } - _ctx = newContext() - } - - /** - * Execute the given block of code with a custom [[TestSQLContext]]. - * At the end of the method, the default [[TestSQLContext]] will be restored. - */ - protected def withSQLContext[T](newContext: () => TestSQLContext)(body: => T) { - switchSQLContext(newContext) - try { - body - } finally { - switchSQLContext(() => new TestSQLContext) - } - } - } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala deleted file mode 100644 index e011d86c69ae..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveDataSourceTest.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.sources.AbstractDataSourceTest - - -/** - * An equivalent of [[org.apache.spark.sql.sources.DataSourceTest]], but for hive tests. - */ -private[hive] abstract class HiveDataSourceTest - extends AbstractDataSourceTest - with HiveTestUtils diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala deleted file mode 100644 index 468e712f64a1..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetCompatibilityTest.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.execution.datasources.parquet.AbstractParquetCompatibilityTest - -/** - * Helper class for testing Parquet compatibility in hive. - * This is analogous to [[org.apache.spark.sql.parquet.ParquetCompatibilityTest]]. - */ -private[hive] abstract class HiveParquetCompatibilityTest - extends AbstractParquetCompatibilityTest - with HiveParquetTest diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala deleted file mode 100644 index 9f053b4b855a..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveParquetTest.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.execution.datasources.parquet.AbstractParquetTest - -/** - * Helper trait for Parquet tests analogous to [[org.apache.spark.sql.parquet.ParquetTest]]. - */ -private[hive] trait HiveParquetTest extends AbstractParquetTest with HiveTestUtils - diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala deleted file mode 100644 index c31ee732670c..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveSparkPlanTest.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.execution.AbstractSparkPlanTest - -/** - * Base class for writing tests for individual physical operators in hive. - * This is analogous to [[org.apache.spark.sql.execution.SparkPlanTest]]. - */ -private[sql] abstract class HiveSparkPlanTest - extends AbstractSparkPlanTest - with HiveTestUtils diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala deleted file mode 100644 index 3b12e96be5e9..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/HiveTestUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.test.AbstractSQLTestUtils - -/** - * Helper trait that should be extended by all SQL test suites involving a - * [[org.apache.spark.sql.hive.HiveContext]]. - * - * This is analogous to [[org.apache.spark.sql.test.SQLTestUtils]] but for hive tests. 
- */ -private[spark] trait HiveTestUtils extends AbstractSQLTestUtils with SharedHiveContext { - protected final override def _sqlContext = hiveContext -} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index fb1b57b83d76..b018693226bf 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.hive.test -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SQLTestUtils /** * Helper trait for hive test suites where all tests share a single [[TestHiveContext]]. * This is analogous to [[org.apache.spark.sql.test.SharedSQLContext]]. */ -private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfterAll { +private[spark] trait SharedHiveContext extends SQLTestUtils { /** * The [[TestHiveContext]] to use for all tests in this suite. @@ -36,6 +34,13 @@ private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfter */ private var _ctx: TestHiveContext = null + /** + * The [[TestHiveContext]] to use for all tests in this suite. + */ + protected def ctx: TestHiveContext = _ctx + protected def hiveContext: TestHiveContext = _ctx + protected override def _sqlContext: TestHiveContext = _ctx + /** * Initialize the [[TestHiveContext]]. * This is a no-op if the user explicitly switched to a custom context before this is called. @@ -57,37 +62,4 @@ private[spark] trait SharedHiveContext extends SparkFunSuite with BeforeAndAfter super.afterAll() } - /** - * The [[TestHiveContext]] to use for all tests in this suite. - */ - protected def ctx: TestHiveContext = _ctx - protected def hiveContext: TestHiveContext = _ctx - - /** - * Switch a custom [[TestHiveContext]]. - * - * This stops the underlying [[org.apache.spark.SparkContext]] and expects a new one - * to be created. This is needed because only one [[org.apache.spark.SparkContext]] - * is allowed per JVM. - */ - protected def switchHiveContext(newContext: () => TestHiveContext): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = newContext() - } - } - - /** - * Execute the given block of code with a custom [[TestHiveContext]]. - * At the end of the method, the default [[TestHiveContext]] will be restored. 
- */ - protected def withHiveContext[T](newContext: () => TestHiveContext)(body: => T) { - switchHiveContext(newContext) - try { - body - } finally { - switchHiveContext(() => new TestHiveContext) - } - } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 1be12bf9a884..99567d667b95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -21,11 +21,11 @@ import java.io.File import org.apache.spark.sql.{SaveMode, AnalysisException, QueryTest} import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest with HiveTestUtils { +class CachedTableSuite extends QueryTest with SharedHiveContext { def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 77eacc66e5e4..7519e8bebe54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,11 +22,11 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter with HiveTestUtils { +class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { import testImplicits._ before { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 0d94bb3cbf78..8202722cb785 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with HiveTestUtils { +class HiveDataFrameAnalyticsSuite extends QueryTest with SharedHiveContext { import testImplicits._ private var _testData: DataFrame = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index dbed3bce6d55..6d0ef530ca65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class HiveDataFrameJoinSuite extends QueryTest with HiveTestUtils { +class HiveDataFrameJoinSuite extends QueryTest with SharedHiveContext { import testImplicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c15bf1cd95d5..4def9557e2ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class HiveDataFrameWindowSuite extends QueryTest with HiveTestUtils { +class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { import testImplicits._ test("reuse window partitionBy") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index bba54329c5b2..7d4ac7e2c430 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -22,12 +22,13 @@ import java.io.File import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} -import org.apache.spark.sql.hive.test.{HiveDataSourceTest, HiveTestUtils} +import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite with HiveTestUtils with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext with Logging { import testImplicits._ test("struct field should accept underscore in sub-column name") { @@ -50,7 +51,7 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with HiveTestUtils with Lo } } -class DataSourceWithHiveMetastoreCatalogSuite extends HiveDataSourceTest { +class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SharedHiveContext { import testImplicits._ private lazy val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index f763d007a10a..b90e52d373b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.HiveParquetTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.hive.test.SharedHiveContext case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with HiveParquetTest { +class HiveParquetSuite extends QueryTest with ParquetTest with SharedHiveContext { test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 4dadb3902be6..4bd8bac0981b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHiveContext, HiveTestUtils} +import org.apache.spark.sql.hive.test.{TestHiveContext, SharedHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -39,7 +39,7 @@ class HiveSparkSubmitSuite with Matchers with ResetSystemProperties with Timeouts - with HiveTestUtils { + with SharedHiveContext { // TODO: rewrite these or mark them as slow tests to be run sparingly diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 5e440411f89c..fdae068d3ccd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types._ import org.apache.spark.util.Utils case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with HiveTestUtils { +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { import testImplicits._ private val _testData = ctx.sparkContext.parallelize( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 4aa97be78ac8..d5db730d04c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.HiveTestUtils +import 
org.apache.spark.sql.hive.test.SharedHiveContext -class ListTablesSuite extends QueryTest with HiveTestUtils { +class ListTablesSuite extends QueryTest with SharedHiveContext { import testImplicits._ val df = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index a0a929adb4b8..faac6f5a1aa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -37,7 +37,7 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest - with HiveTestUtils + with SharedHiveContext with Logging { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 997eda8fde36..1ff256d3b209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.{QueryTest, SaveMode} -class MultiDatabaseSuite extends QueryTest with HiveTestUtils { +class MultiDatabaseSuite extends QueryTest with SharedHiveContext { private lazy val df = ctx.range(10).coalesce(1) test(s"saveAsTable() to non-default database - with USE - Overwrite") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index e13cda1a7926..8dca1a77c8aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.hive.test.HiveParquetCompatibilityTest +import org.apache.spark.sql.hive.test.SharedHiveContext -class ParquetHiveCompatibilitySuite extends HiveParquetCompatibilityTest { +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with SharedHiveContext { import ParquetCompatibilityTest.makeNullable /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 910ec4654cf8..418205aee0dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.hive.test.HiveTestUtils +import 
org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.QueryTest import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest with HiveTestUtils { +class QueryPartitionSuite extends QueryTest with SharedHiveContext { import testImplicits._ test("SPARK-5068: query data when path doesn't exist"){ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index b7a8ba493687..7ce385f1a613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class SerializationSuite extends SparkFunSuite with HiveTestUtils { +class SerializationSuite extends SparkFunSuite with SharedHiveContext { test("[SPARK-5840] HiveContext should be serializable") { ctx.hiveconf diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f6963c917d27..7dcedaa02153 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,9 +22,9 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class StatisticsSuite extends QueryTest with HiveTestUtils { +class StatisticsSuite extends QueryTest with SharedHiveContext { protected override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 954e4201d923..280587076755 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with HiveTestUtils { +class UDFSuite extends QueryTest with SharedHiveContext { test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0a79d33b1011..88803d059bea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import 
org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -abstract class AggregationQuerySuite extends QueryTest with HiveTestUtils { +abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { import testImplicits._ var originalUseAggregate2: Boolean = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 99f139047b1f..6d391c57b8c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext /** * Allows the creations of tests that execute the same query against both hive @@ -43,7 +43,7 @@ import org.apache.spark.sql.hive.test.HiveTestUtils abstract class HiveComparisonTest extends SparkFunSuite with GivenWhenThen - with HiveTestUtils + with SharedHiveContext with Logging { /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 1c9941949a9e..cfc997395d3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest with HiveTestUtils { +class HiveExplainSuite extends QueryTest with SharedHiveContext { test("explain extended command") { checkExistence(ctx.sql(" explain select * from src where key=123 "), true, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index a35548a6a979..63876dc5f0cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest with HiveTestUtils { +class HiveOperatorQueryableSuite extends QueryTest with SharedHiveContext { test("SPARK-5324 query result of describe command") { ctx.loadTestTable("src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index 3b9ed9107363..cd1bfa43ecca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class HivePlanTest extends QueryTest with HiveTestUtils { +class HivePlanTest extends QueryTest with SharedHiveContext { import testImplicits._ test("udf constant folding") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index c4384dbcc7d1..8fbeb3498fef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.util.Utils case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. 
*/ -class HiveUDFSuite extends QueryTest with HiveTestUtils { +class HiveUDFSuite extends QueryTest with SharedHiveContext { import testImplicits._ test("spark sql udf test that returns a struct") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3cdb4d0218b1..4fd877970b05 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestData.TestData @@ -63,7 +63,7 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest with HiveTestUtils { +class SQLQuerySuite extends QueryTest with SharedHiveContext { import testImplicits._ test("UDTF") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 72457f1d2390..53678ce73302 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,11 +24,11 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} -import org.apache.spark.sql.hive.test.HiveSparkPlanTest +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends HiveSparkPlanTest { +class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 0cda8cb03115..dfbbc21539a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext // The data where the partitioning key exists only in the directory structure. 
case class OrcParData(intField: Int, stringField: String) @@ -34,7 +34,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with HiveTestUtils { +class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { import testImplicits._ val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 40e2bc71b3f0..a3abe8882d87 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with HiveTestUtils { +abstract class OrcSuite extends QueryTest with SharedHiveContext { import testImplicits._ var orcTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 09518696b974..d974011c4699 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -24,9 +24,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -private[sql] trait OrcTest extends SparkFunSuite with HiveTestUtils { +private[sql] trait OrcTest extends SparkFunSuite with SharedHiveContext { import testImplicits._ /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index ed474ad141ba..204eed161e1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -684,7 +684,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. 
*/ -abstract class ParquetPartitioningTest extends QueryTest with HiveTestUtils { +abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext { import testImplicits._ var partitionedTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index 8bd7dba84ec3..6ab7bc4fde90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext -class CommitFailureTestRelationSuite extends SparkFunSuite with HiveTestUtils { +class CommitFailureTestRelationSuite extends SparkFunSuite with SharedHiveContext { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 010304ac4912..37abeade74e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,11 +28,11 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with HiveTestUtils { +abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { import testImplicits._ val dataSourceName: String From 1cf53adefc1e8300a6e41f7f6e7d14112373344e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 16:51:08 -0700 Subject: [PATCH 19/39] Create new context in SBT console by default This restores the old behavior if the developer drops into console mode from SBT. 
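As a rough illustration of the behavior being restored, the sketch below shows the kind of session the new initialCommands are intended to support once the developer drops into the console (for example via build/sbt sql/console; the exact project name, DataFrame, and column are illustrative assumptions, not part of this patch):

    // `sc` and `sqlContext` are pre-created by initialCommands, so no manual setup is needed.
    val df = sqlContext.range(0, 10)              // illustrative DataFrame with an "id" column
    df.filter($"id" > 5).show()                   // sqlContext.implicits._ is already imported
    // cleanupCommands runs sc.stop() when the console session ends.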
--- project/SparkBuild.scala | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 56e19a9ddaf7..d976a19f4a9b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -315,11 +315,12 @@ object OldDeps { ) } -// TODO: check if this is OK object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -329,14 +330,19 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } object Hive { - // TODO: check me, will this work? lazy val settings = Seq( javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them @@ -348,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -358,8 +365,14 @@ object Hive { |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()", + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val hc = new HiveContext(sc) + |import hc.implicits._ + |import hc._ + """.stripMargin, + cleanupCommands in console := "sc.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. 
From bc5c9992a840296b2930622a7eab8f8d5a84bcf8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 18:06:22 -0700 Subject: [PATCH 20/39] Fix SemiJoinSuite --- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../datasources/json/JsonSuite.scala | 74 +++++++++---------- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../sql/execution/joins/SemiJoinSuite.scala | 52 +++++++------ 7 files changed, 72 insertions(+), 64 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0f20c775c120..b412af9d5112 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.storage.{StorageLevel, RDDBlockId} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.storage.{StorageLevel, RDDBlockId} private case class BigData(s: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index df1e784dce0a..053c0f052d9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,8 +21,8 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2e26eccc9ebe..872bae647e84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -310,7 +310,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("udf") { - val foo = org.apache.spark.sql.functions.udf((a: Int, b: String) => a.toString + b) + val foo = udf((a: Int, b: String) => a.toString + b) checkAnswer( // SELECT *, foo(key, value) FROM testData diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 821160c0abd2..d029321598cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext with BeforeAndAfter { import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 103c3d510177..636f55763a21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -215,7 +215,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = _sqlContext.read.json(jsonNullStruct) + val jsonDF = ctx.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +234,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = _sqlContext.read.json(primitiveFieldAndType) + val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -262,7 +262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = _sqlContext.read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -363,7 +363,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = _sqlContext.read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -379,7 +379,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = _sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -451,7 +451,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = _sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. 
@@ -504,7 +504,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = _sqlContext.read.json(complexFieldValueTypeConflict) + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -528,7 +528,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = _sqlContext.read.json(arrayElementTypeConflict) + val jsonDF = ctx.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -556,7 +556,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = _sqlContext.read.json(missingFields) + val jsonDF = ctx.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -577,7 +577,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalFile.toURI.toString ctx.sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = _sqlContext.read.option("samplingRatio", "0.49").json(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -592,7 +592,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - _sqlContext.read.schema(schema).json(path) + ctx.read.schema(schema).json(path) .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) @@ -605,7 +605,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = _sqlContext.read.json(path) + val jsonDF = ctx.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -674,7 +674,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = _sqlContext.read.schema(schema).json(path) + val jsonDF1 = ctx.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -691,7 +691,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = _sqlContext.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -712,7 +712,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = _sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -740,7 +740,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", 
MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = _sqlContext.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -766,7 +766,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = _sqlContext.read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -784,7 +784,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = _sqlContext.read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -807,7 +807,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = _sqlContext.read.json(jsonArray) + val jsonDF = ctx.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -825,10 +825,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = _sqlContext.conf.columnNameOfCorruptRecord - _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = _sqlContext.read.json(corruptRecords) + val jsonDF = ctx.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -878,11 +878,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row("]") :: Nil ) - _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = _sqlContext.read.json(nullsInArrays) + val jsonDF = ctx.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -928,7 +928,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = _sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = ctx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -951,7 +951,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = _sqlContext.createDataFrame(rowRDD2, schema2) + val df3 = ctx.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -959,8 +959,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = _sqlContext.read.json(primitiveFieldAndType) - val primTable = _sqlContext.read.json(jsonDF.toJSON) + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") 
checkAnswer( ctx.sql("select * from primativeTable"), @@ -972,8 +972,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = _sqlContext.read.json(complexFieldAndType1) - val compTable = _sqlContext.read.json(complexJsonDF.toJSON) + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1108,24 +1108,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = _sqlContext.conf.columnNameOfCorruptRecord - _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try { val temp = Utils.createTempDir().getPath - val df = _sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(_sqlContext.read.parquet(temp).count() == 5) + assert(ctx.read.parquet(temp).count() == 5) - val df2 = _sqlContext.read.json(corruptRecords) + val df2 = ctx.read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(_sqlContext.read.parquet(temp), df2.collect()) + checkAnswer(ctx.read.parquet(temp), df2.collect()) } finally { - _sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index a7ef11c5b92b..d53dcc6a675f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index cfe6f57d56cf..baa86e320d98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -51,38 +51,46 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private val condition = { + private lazy val condition = { And((left.col("a") === right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testLeftSemiJoin( testName: String, - 
leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using LeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( - LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using BroadcastLeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } } test(s"$testName using LeftSemiJoinBNL") { From 19fd6c3910c52d366b102916d90b334ea2f0bd69 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 18:20:03 -0700 Subject: [PATCH 21/39] Fix InnerJoinSuite --- .../sql/execution/joins/InnerJoinSuite.scala | 280 ++++++++++-------- 1 file changed, 155 insertions(+), 125 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index d53dcc6a675f..cc649b9bd4c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -28,157 +28,187 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + private lazy val myUpperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + private lazy val myLowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", 
StringType)) + + private lazy val myTestData = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testInnerJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - - def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) - } - def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) - val filteredJoin = - boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } - def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { - val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } - test(s"$testName using BroadcastHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using BroadcastHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using ShuffledHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using ShuffledHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using SortMergeJoin") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeSortMergeJoin(left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } } } - { - lazy val upperCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( - Row(1, "A"), - Row(2, "B"), - Row(3, "C"), - Row(4, "D"), - Row(5, "E"), - Row(6, "F"), - Row(null, "G") - )), new StructType().add("N", 
IntegerType).add("L", StringType)) - - lazy val lowerCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"), - Row(4, "d"), - Row(null, "e") - )), new StructType().add("n", IntegerType).add("l", StringType)) + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") testInnerJoin( - "inner join, one match per row", - upperCaseData, - lowerCaseData, - (upperCaseData.col("N") === lowerCaseData.col("n")).expr, + "inner join, multiple matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) ) ) } { - lazy val testData2 = Seq( - (1, 1), - (1, 2), - (2, 1), - (2, 2), - (3, 1), - (3, 2) - ).toDF("a", "b") - - { - lazy val left = testData2.where("a = 1") - lazy val right = testData2.where("a = 1") - testInnerJoin( - "inner join, multiple matches", - left, - right, - (left.col("a") === right.col("a")).expr, - Seq( - (1, 1, 1, 1), - (1, 1, 1, 2), - (1, 2, 1, 1), - (1, 2, 1, 2) - ) - ) - } - - { - lazy val left = testData2.where("a = 1") - lazy val right = testData2.where("a = 2") - testInnerJoin( - "inner join, no matches", - left, - right, - (left.col("a") === right.col("a")).expr, - Seq.empty - ) - } + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq.empty + ) } } From 1e4c32115161399f5871e5327c61e2bbec00e7ed Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 00:21:23 -0700 Subject: [PATCH 22/39] Address comments --- sql/README.md | 4 ++++ .../analysis/AnalysisErrorSuite.scala | 6 +----- .../apache/spark/sql/JavaDataFrameSuite.java | 2 +- .../apache/spark/sql/ListTablesSuite.scala | 11 +++++++--- .../apache/spark/sql/SQLContextSuite.scala | 7 +++++-- .../columnar/PartitionBatchPruningSuite.scala | 21 ++++++++----------- .../sql/execution/TungstenSortSuite.scala | 7 +++++-- .../sources/CreateTableAsSelectSuite.scala | 16 +++++++------- .../spark/sql/sources/DDLTestSuite.scala | 5 +++-- .../spark/sql/sources/FilteredScanSuite.scala | 7 +++---- .../spark/sql/sources/InsertSuite.scala | 11 ++++++---- .../spark/sql/sources/PrunedScanSuite.scala | 7 +++---- .../spark/sql/sources/SaveLoadSuite.scala | 17 +++++++-------- .../spark/sql/sources/TableScanSuite.scala | 7 +++---- .../spark/sql/test/SharedSQLContext.scala | 1 + .../sql/hive/test/SharedHiveContext.scala | 2 ++ .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../spark/sql/hive/ErrorPositionSuite.scala | 3 ++- .../sql/hive/InsertIntoHiveTableSuite.scala | 6 +++--- .../sql/hive/execution/HiveQuerySuite.scala | 19 +++++++++-------- 20 files changed, 85 insertions(+), 76 deletions(-) diff --git a/sql/README.md b/sql/README.md index 4b8074d85585..266cb92f1b7e 100644 --- a/sql/README.md +++ b/sql/README.md @@ -61,6 +61,10 @@ import org.apache.spark.sql.execution import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.types._ +sc: org.apache.spark.SparkContext = org.apache.spark.SparkContext@27fc0441 +hc: 
org.apache.spark.sql.hive.HiveContext = org.apache.spark.sql.hive.HiveContext@127b5be9 +import hc.implicits._ +import hc._ Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 63b475b6366c..f60d11c988ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types._ @@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { +class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ def errorTest( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4f630e4d5629..08922a2162a6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -34,9 +34,9 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2a80cab0bc51..147d7ca6a9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -27,12 +27,17 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") - before { + override def beforeAll(): Unit = { + super.beforeAll() df.registerTempTable("ListTablesSuiteTable") } - after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + override def afterAll(): Unit = { + try { + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + } finally { + super.afterAll() + } } test("get all tables") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 056570a0c978..007be1295077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.test.SharedSQLContext class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll():
Unit = { - SQLContext.setLastInstantiatedContext(ctx) - super.afterAll() + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index d029321598cc..7d366e46e79c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext with BeforeAndAfter { +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize @@ -45,20 +45,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext wit ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - } - - override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - super.afterAll() - } - - before { ctx.cacheTable("pruningData") } - after { - ctx.uncacheTable("pruningData") + override protected def afterAll(): Unit = { + try { + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 20eef186a46f..3158458edb83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -36,8 +36,11 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) - super.afterAll() + try { + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } finally { + super.afterAll() + } } test("sort followed by limit") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 82e9ce954d4b..f1d793e4f3cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class 
CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null @@ -39,12 +37,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") - super.afterAll() - } - - after { - Utils.deleteRecursively(path) + try { + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("CREATE TEMPORARY TABLE AS SELECT") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 6d9a61175c9d..7691e37d62ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -71,9 +71,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { - before { + override def beforeAll(): Unit = { + super.beforeAll() caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE ddlPeople diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 613467e59efb..12ba1ec6accd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.sources import scala.language.existentials -import org.scalatest.BeforeAndAfter - import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext @@ -99,9 +97,10 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { - before { + override def beforeAll(): Unit = { + super.beforeAll() caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index c880d5f7df6a..925f9647d7c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -43,10 +43,13 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) - super.afterAll() + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple INSERT OVERWRITE a JSONRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 6130780e8044..c5dd8aae07b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,8 +19,6 
@@ package org.apache.spark.sql.sources import scala.language.existentials -import org.scalatest.BeforeAndAfter - import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext @@ -54,9 +52,10 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { - before { + override def beforeAll(): Unit = { + super.beforeAll() caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenPruned diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index ba41e89738f3..d83278ea6d2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class SaveLoadSuite extends DataSourceTest with SharedSQLContext { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null @@ -45,13 +43,12 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - super.afterAll() - } - - after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - Utils.deleteRecursively(path) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5bde48740c90..853273679864 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.sources import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import org.scalatest.BeforeAndAfter - import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext @@ -98,7 +96,7 @@ case class AllDataTypesScan( } } -class TableScanSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { +class TableScanSuite extends DataSourceTest with SharedSQLContext { private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", @@ -123,7 +121,8 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext with BeforeAnd Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTen diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index c1e5939b83a2..18e4abb35d7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -48,6 +48,7 @@ private[sql] trait SharedSQLContext extends SQLTestUtils { if (_ctx == null) { _ctx = new TestSQLContext } + // Ensure we have initialized the context before calling parent code super.beforeAll() } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index b018693226bf..d7dd4b6882af 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -49,6 +49,8 @@ private[spark] trait SharedHiveContext extends SQLTestUtils { if (_ctx == null) { _ctx = new TestHiveContext } + // Ensure we have initialized the context before calling parent code + super.beforeAll() } /** diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index b3a88eb797a0..d93add697fd7 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -30,10 +30,10 @@ import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHiveContext; import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum; -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext sc; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 7519e8bebe54..fdb9725b2578 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.{AnalysisException, QueryTest} class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { import testImplicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index fdae068d3ccd..167ad9b5d8bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -29,12 +29,12 @@ import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -case class ThreeCloumntable(key: Int, value: String, key1: String) +case class ThreeColumnTable(key: Int, value: String, key1: String) class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { import testImplicits._ - private val 
_testData = ctx.sparkContext.parallelize( + private lazy val _testData = ctx.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { @@ -217,7 +217,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with Shared testData.registerTempTable("testData") val testDatawithNull = ctx.sparkContext.parallelize( - (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() + (1 to 10).map(i => ThreeColumnTable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() ctx.sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f3a902a87e15..212abc780f03 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,8 +22,6 @@ import java.util.{Locale, TimeZone} import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} @@ -37,7 +35,7 @@ import org.apache.spark.sql.test.SQLTestData.TestData * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { +class HiveQuerySuite extends HiveComparisonTest { import testImplicits._ private val originalTimeZone = TimeZone.getDefault @@ -53,11 +51,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } override def afterAll(): Unit = { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.sql("DROP TEMPORARY FUNCTION udtf_count2") - super.afterAll() + try { + ctx.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + ctx.sql("DROP TEMPORARY FUNCTION udtf_count2") + } finally { + super.afterAll() + } } test("SPARK-4908: concurrent hive native commands") { @@ -1141,4 +1142,4 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } // for SPARK-2180 test -case class HavingRow(key: Int, value: String, attr: Int) +private case class HavingRow(key: Int, value: String, attr: Int) From 94f9c77960c7c4bd8ae0623c845b6aa250324fb8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 00:36:14 -0700 Subject: [PATCH 23/39] Revert the removal of some BeforeAndAfters Some code actually needed to be called before / after every test. We should not change the semantics in these tests. 
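For context, here is a minimal self-contained ScalaTest sketch (a hypothetical suite, not part of this patch) of the semantic difference being restored: before/after hooks wrap every single test, while beforeAll/afterAll run only once per suite, so moving setup or cleanup between them changes when it executes.

    import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}

    // Hypothetical suite illustrating the two hook granularities.
    class HookSemanticsSketch extends FunSuite with BeforeAndAfter with BeforeAndAfterAll {
      before { println("per-test setup") }      // runs before each of the two tests
      after { println("per-test cleanup") }     // runs after each of the two tests

      override def beforeAll(): Unit = println("one-time suite setup")
      override def afterAll(): Unit = println("one-time suite cleanup")

      test("first") { assert(Seq(1, 2).sum == 3) }
      test("second") { assert("spark".nonEmpty) }
    }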
--- .../scala/org/apache/spark/sql/ListTablesSuite.scala | 11 +++-------- .../sql/columnar/PartitionBatchPruningSuite.scala | 2 -- .../spark/sql/sources/CreateTableAsSelectSuite.scala | 9 +++++++-- .../org/apache/spark/sql/sources/SaveLoadSuite.scala | 9 +++++++-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 63381130a795..2a80cab0bc51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -27,17 +27,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") - override def beforeAll(): Unit = { - super.beforeAll() + before { df.registerTempTable("ListTablesSuiteTable") } - override def afterAll(): Unit = { - try { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - } finally { - super.afterAll() - } + after { + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 7d366e46e79c..fb9ff2f50325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.columnar -import org.scalatest.BeforeAndAfter - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index f1d793e4f3cc..59e363ef8dbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext { +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null @@ -39,12 +41,15 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext { override def afterAll(): Unit = { try { caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) } finally { super.afterAll() } } + after { + Utils.deleteRecursively(path) + } + test("CREATE TEMPORARY TABLE AS SELECT") { caseInsensitiveContext.sql( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index d83278ea6d2a..8c463ba8802a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,12 
+19,14 @@ package org.apache.spark.sql.sources import java.io.File +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with SharedSQLContext { +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null @@ -45,12 +47,15 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext { override def afterAll(): Unit = { try { caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - Utils.deleteRecursively(path) } finally { super.afterAll() } } + after { + Utils.deleteRecursively(path) + } + def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { caseInsensitiveContext.conf.setConf( SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") From bec7d282777d7779f308941c6ce30167aa36bed5 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 01:15:49 -0700 Subject: [PATCH 24/39] Fix a hive test + minor format updates --- .../hive/execution/HiveCompatibilitySuite.scala | 2 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- .../apache/spark/sql/hive/ListTablesSuite.scala | 17 +++++++++++------ .../hive/execution/ConcurrentHiveSuite.scala | 11 ++++++----- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0cb5250dfdd7..619f0fcaa1a4 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -42,7 +42,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { - ctx.cacheTables = true + ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 99567d667b95..f8b236af7dee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.{SaveMode, AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.hive.test.SharedHiveContext import org.apache.spark.storage.RDDBlockId diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index d5db730d04c6..1a2b5eb01e09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -23,10 +23,11 @@ import org.apache.spark.sql.hive.test.SharedHiveContext class ListTablesSuite extends QueryTest with SharedHiveContext { import testImplicits._ - val df = + private lazy val df = ctx.sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { + super.beforeAll() // The catalog in HiveContext is a case insensitive one. ctx.catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) ctx.catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) @@ -36,11 +37,15 @@ class ListTablesSuite extends QueryTest with SharedHiveContext { } override def afterAll(): Unit = { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) - ctx.sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") - ctx.sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") - ctx.sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + try { + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + ctx.sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + ctx.sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + ctx.sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } finally { + super.afterAll() + } } test("get all tables of current database") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index d54dd7de5751..6bae033be915 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => - val sc = new SparkContext("local", s"TestSQLContext$i", new SparkConf()) + var sc: SparkContext = null try { + sc = new SparkContext("local", s"TestSQLContext$i", new SparkConf()) val ts = new TestHiveContext(sc) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() } finally { - sc.stop() + if (sc != null) { + sc.stop() + } } } } From aaba277652df561b7d77583f99862d9f84fd9eec Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 14:09:06 -0700 Subject: [PATCH 25/39] Fix places where we override before / after alls The tests are inherently unstable because there are places where we either (1) never start a SQLContext before accessing it, or (2) never stop it. This commit does a comprehensive search of all of these cases and fixes them one by one. 
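The convention applied throughout this patch, sketched below with placeholder setup/cleanup rather than any one suite's code: call super.beforeAll() first so the shared context exists before suite-specific setup touches it, and wrap suite-specific cleanup in try/finally so super.afterAll() always stops the context even when the cleanup itself throws.

    override def beforeAll(): Unit = {
      super.beforeAll()            // starts the shared SQLContext / HiveContext
      // suite-specific setup that uses ctx goes here
    }

    override def afterAll(): Unit = {
      try {
        // suite-specific cleanup; may throw
      } finally {
        super.afterAll()           // always stops / resets the shared context
      }
    }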
--- .../spark/sql/test/SharedSQLContext.scala | 11 ++++-- .../sql/hive/test/SharedHiveContext.scala | 11 ++++-- .../hive/HiveDataFrameAnalyticsSuite.scala | 7 +++- .../sql/hive/MetastoreDataSourcesSuite.scala | 1 + .../execution/AggregationQuerySuite.scala | 38 +++++++++++++------ .../spark/sql/hive/orc/OrcSourceSuite.scala | 8 +++- .../apache/spark/sql/hive/parquetSuites.scala | 37 +++++++++++------- 7 files changed, 75 insertions(+), 38 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 18e4abb35d7f..20ff949d6766 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -56,11 +56,14 @@ private[sql] trait SharedSQLContext extends SQLTestUtils { * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() } - super.afterAll() } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index d7dd4b6882af..2fc305ad70b2 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -57,11 +57,14 @@ private[spark] trait SharedHiveContext extends SQLTestUtils { * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() } - super.afterAll() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 8202722cb785..33dcd8a484d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -36,8 +36,11 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with SharedHiveContext { } override def afterAll(): Unit = { - ctx.dropTempTable("mytable") - super.afterAll() + try { + ctx.dropTempTable("mytable") + } finally { + super.afterAll() + } } test("rollup") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index faac6f5a1aa2..e1d2804d15b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -45,6 +45,7 @@ class MetastoreDataSourcesSuite var jsonFilePath: String = _ override def beforeAll(): Unit = { + super.beforeAll() jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 88803d059bea..8d1e9c92e992 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -29,6 +29,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { var originalUseAggregate2: Boolean = _ override def beforeAll(): Unit = { + super.beforeAll() originalUseAggregate2 = ctx.conf.useSqlAggregate2 ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( @@ -74,10 +75,14 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { } override def afterAll(): Unit = { - ctx.sql("DROP TABLE IF EXISTS agg1") - ctx.sql("DROP TABLE IF EXISTS agg2") - ctx.dropTempTable("emptyTable") - ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + try { + ctx.sql("DROP TABLE IF EXISTS agg1") + ctx.sql("DROP TABLE IF EXISTS agg2") + ctx.dropTempTable("emptyTable") + ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + } finally { + super.afterAll() + } } test("empty table") { @@ -540,8 +545,11 @@ class SortBasedAggregationQuerySuite extends AggregationQuerySuite { } override def afterAll(): Unit = { - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - super.afterAll() + try { + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } finally { + super.afterAll() + } } } @@ -556,8 +564,11 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { } override def afterAll(): Unit = { - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - super.afterAll() + try { + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } finally { + super.afterAll() + } } } @@ -566,15 +577,18 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + super.beforeAll() originalUnsafeEnabled = ctx.conf.unsafeEnabled ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() } override def afterAll(): Unit = { - super.afterAll() - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - ctx.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + try { + ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + ctx.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } finally { + super.afterAll() + } } override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index a3abe8882d87..641aafb24884 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -64,8 +64,12 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { } override def afterAll(): Unit = { - orcTableDir.delete() - orcTableAsDir.delete() + try { + orcTableDir.delete() + orcTableAsDir.delete() + } finally { + super.afterAll() + } } test("create temporary orc table") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 204eed161e1d..65c7387c12a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -175,15 +175,19 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) + try { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) + } finally { + super.afterAll() + } } test(s"conversion is working") { @@ -694,6 +698,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext var partitionedTableDirWithKeyAndComplexTypes: File = null override def beforeAll(): Unit = { + super.beforeAll() partitionedTableDir = Utils.createTempDir() normalTableDir = Utils.createTempDir() @@ -742,11 +747,15 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext } override protected def afterAll(): Unit = { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } } /** From 40959bbfeb1298063caad5565b48844ef05d8fe2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 16:11:21 -0700 Subject: [PATCH 26/39] Fix test The sampleBy test was failing deterministically. This is because we changed the default number of cores the SparkContext uses from 2 to *. Instead we should just explicitly set the number of partitions in the test. 
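For context, the Scala-side shape of the fix (a sketch only; `ctx` stands for the suite's SQLContext and the sampled counts still depend on the seed): pinning the partition count in range() keeps the stratified sample stable no matter how many cores the test SparkContext gets.

    import org.apache.spark.sql.functions.col

    // 2 partitions explicitly, instead of relying on the context's default parallelism
    val df = ctx.range(0, 100, 1, 2).select((col("id") % 3).as("key"))
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 0L)
    sampled.groupBy("key").count().orderBy("key").show()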
--- .../test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 08922a2162a6..7abdd3db8034 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -231,7 +231,7 @@ public void testCovariance() { @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; From 4ed58c8bf1ec5b0b43c63ef82a196d6a3f967425 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 16:26:59 -0700 Subject: [PATCH 27/39] Revert a few merge-conflict-induced unintentional changes --- .../ParquetAvroCompatibilitySuite.scala | 11 -------- .../parquet/ParquetCompatibilityTest.scala | 28 ------------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../spark/sql/test/SharedSQLContext.scala | 1 - .../execution/HiveCompatibilitySuite.scala | 23 +++++++++------ .../sql/hive/test/SharedHiveContext.scala | 1 - .../hive/ParquetHiveCompatibilitySuite.scala | 8 ------ 7 files changed, 15 insertions(+), 59 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index a2d29eac6e65..070db67d54e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -35,17 +35,6 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared import ParquetCompatibilityTest._ import testImplicits._ - override protected def beforeAll(): Unit = { - super.beforeAll() - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() - } - private def withWriter[T <: IndexedRecord] (path: String, schema: Schema) (f: AvroParquetWriter[T] => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 57b4ae1495ff..b3406729fcc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.io.File - import scala.collection.JavaConversions._ import org.apache.hadoop.fs.{Path, PathFilter} @@ -26,37 +24,11 @@ import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType import org.apache.spark.sql.QueryTest -import 
org.apache.spark.util.Utils /** * Helper class for testing Parquet compatibility. */ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { - - protected var parquetStore: File = _ - - /** - * Optional path to a staging subdirectory which may be created during query processing - * (Hive does this). - * Parquet files under this directory will be ignored in [[readParquetSchema()]] - * @return an optional staging directory to ignore when scanning for parquet files. - */ - protected def stagingDir: Option[String] = None - - override protected def beforeAll(): Unit = { - super.beforeAll() - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - try { - Utils.deleteRecursively(parquetStore) - } finally { - super.afterAll() - } - } - protected def readParquetSchema(path: String): MessageType = { readParquetSchema(path, { path => !path.getName.startsWith("_") }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2a79202c6d49..2b2915f36b95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,9 +25,9 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 20ff949d6766..3cfd822e2a74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,7 +42,6 @@ private[sql] trait SharedSQLContext extends SQLTestUtils { /** * Initialize the [[TestSQLContext]]. - * This is a no-op if the user explicitly switched to a custom context before this is called. 
*/ protected override def beforeAll(): Unit = { if (_ctx == null) { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 619f0fcaa1a4..abec71c29d15 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -36,12 +36,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private val originalColumnBatchSize = ctx.conf.columnBatchSize - private val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { + super.beforeAll() ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -55,14 +56,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { } override def afterAll() { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + try { + ctx.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - // For debugging dump some statistics about how much time was spent in various optimizer rules. - logWarning(RuleExecutor.dumpTimeSpent()) + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) + } finally { + super.afterAll() + } } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala index 2fc305ad70b2..b18fca6bcd92 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala @@ -43,7 +43,6 @@ private[spark] trait SharedHiveContext extends SQLTestUtils { /** * Initialize the [[TestHiveContext]]. - * This is a no-op if the user explicitly switched to a custom context before this is called. 
*/ protected override def beforeAll(): Unit = { if (_ctx == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 1c59cc10f827..94b866340d33 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -32,14 +32,6 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with Shared */ private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - override protected def afterAll(): Unit = { - try { - ctx.sql("DROP TABLE parquet_compat") - } finally { - super.afterAll() - } - } - test("Read Parquet file generated by parquet-hive") { withTable("parquet_compat") { withTempPath { dir => From 0ce5638294e19526d7479d62a69876a998b6553f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 17:13:20 -0700 Subject: [PATCH 28/39] Minor updates --- .../org/apache/spark/sql/JoinSuite.scala | 4 +--- .../ParquetAvroCompatibilitySuite.scala | 3 ++- .../spark/sql/sources/DDLTestSuite.scala | 2 -- .../spark/sql/test/TestSQLContext.scala | 2 +- .../spark/sql/hive/test/TestHiveContext.scala | 22 ++++++++----------- 5 files changed, 13 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a963f33ed60..e52a6f96b921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with BeforeAndAfterEach with SharedSQLContext { +class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 070db67d54e6..82d40e2b61a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - import testImplicits._ private def withWriter[T <: IndexedRecord] (path: String, schema: Schema) @@ -128,6 +127,8 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared } test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { + import testImplicits._ + withTempPath { dir => val path = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 7691e37d62ec..59cdb3fd6cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ 
-17,8 +17,6 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 091cb8d4d5e9..9a556a7ce3d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -45,7 +45,7 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel testData.loadTestData() } - object testData extends SQLTestData { + private object testData extends SQLTestData { protected override def _sqlContext: SQLContext = self } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 0fc07c48e378..af674f24629c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -49,9 +49,7 @@ import scala.collection.JavaConversions._ * Calling [[reset]] will delete all tables and other state in the database, leaving the database * in a "clean" state. */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { - self => - +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => import HiveContext._ import TestHiveContext._ @@ -78,7 +76,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") - lazy val scratchDirPath = { + private lazy val scratchDirPath = { val dir = Utils.createTempDir(namePrefix = "scratch-") dir.delete() dir @@ -96,15 +94,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { ) } - val testTempDir = Utils.createTempDir() + private val testTempDir = Utils.createTempDir() // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) - /** The location of the compiled hive distribution */ - lazy val hiveHome = envVarToFile("HIVE_HOME") /** The location of the hive source code. */ - lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + private lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") // Override so we can intercept relative paths and rewrite them to point at hive. 
override def runSqlHive(sql: String): Seq[String] = @@ -149,12 +145,12 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { cmd } - val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + private val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() Utils.registerShutdownDeleteDir(hiveFilesTemp) - val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + private val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + @@ -169,7 +165,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { .getOrElse(new File(inRepoTests, stripped)) } - val describedTable = "DESCRIBE (\\w+)".r + private val describedTable = "DESCRIBE (\\w+)".r /** * Override QueryExecution with special debug workflow. @@ -207,7 +203,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * demand when a query are run against it. */ @transient - lazy val testTables = new mutable.HashMap[String, TestTable]() + private lazy val testTables = new mutable.HashMap[String, TestTable]() def registerTestTable(testTable: TestTable): Unit = { testTables += (testTable.name -> testTable) @@ -217,7 +213,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql @transient - val hiveQTestUtilTables = Seq( + private val hiveQTestUtilTables = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), From 814df2fdc4d111c9d20319252f8e912ed64aac13 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 17:25:35 -0700 Subject: [PATCH 29/39] Add back singletons but deprecate them --- .../spark/sql/test/TestSQLContext.scala | 10 ++++++ .../apache/spark/sql/hive/test/TestHive.scala | 34 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 9a556a7ce3d4..d34de0c62211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.{SQLConf, SQLContext} + /** * A special [[SQLContext]] prepared for testing. */ @@ -49,3 +50,12 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel protected override def _sqlContext: SQLContext = self } } + +// Note: this should NOT be used for internal Spark unit tests because the singleton makes it +// very difficult to start a SQLContext with a custom underlying SparkContext (SPARK-9580). 
+@deprecated("instantiate new TestSQLContext instead of using this singleton", "1.5.0") +object TestSQLContext extends SQLContext( + new SparkContext( + "local[2]", + "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala new file mode 100644 index 000000000000..cdb9eb270af7 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.{SparkConf, SparkContext} + +// Note: this should NOT be used for internal Spark unit tests because the singleton makes it +// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). +@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") +object TestHive extends TestHiveContext( + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) From 451fa37bb8324a5b93d8bbc31038f26a4c12b2f3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 17:32:31 -0700 Subject: [PATCH 30/39] Use consistent name for added back singletons --- .../spark/sql/test/TestSQLContext.scala | 16 +++++---- .../spark/sql/hive/test/TestHiveContext.scala | 31 +++++++++++------ .../apache/spark/sql/hive/test/TestHive.scala | 34 ------------------- 3 files changed, 29 insertions(+), 52 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index d34de0c62211..6e4c16f16f5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -27,8 +27,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext} private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => def this() { - this(new SparkContext("local[2]", "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true"))) + this(TestSQLContext.defaultSparkContext()) } // Use fewer partitions to speed up testing @@ -51,11 +50,14 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } } +private[sql] object TestSQLContext { + def defaultSparkContext(): SparkContext = { + new SparkContext("local[2]", 
"test-sql-context", + new SparkConf().set("spark.sql.testkey", "true")) + } +} + // Note: this should NOT be used for internal Spark unit tests because the singleton makes it // very difficult to start a SQLContext with a custom underlying SparkContext (SPARK-9580). @deprecated("instantiate new TestSQLContext instead of using this singleton", "1.5.0") -object TestSQLContext extends SQLContext( - new SparkContext( - "local[2]", - "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true"))) +object SingletonTestSQLContext extends TestSQLContext(TestSQLContext.defaultSparkContext()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 22073a856980..2c8114fecb2a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -54,17 +54,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => import TestHiveContext._ def this() { - this(new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - // SPARK-3729: Test key required to check for initialization errors with config. - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) + this(defaultSparkContext()) } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests @@ -456,4 +446,23 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => private[hive] object TestHiveContext { case class TestTable(name: String, commands: (() => Unit)*) + + def defaultSparkContext(): SparkContext = { + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + // SPARK-3729: Test key required to check for initialization errors with config. + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") + // SPARK-8910 + .set("spark.ui.enabled", "false")) + } } + +// Note: this should NOT be used for internal Spark unit tests because the singleton makes it +// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). +@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") +object GlobalTestHiveContext extends TestHiveContext(TestHiveContext.defaultSparkContext()) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala deleted file mode 100644 index cdb9eb270af7..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.{SparkConf, SparkContext} - -// Note: this should NOT be used for internal Spark unit tests because the singleton makes it -// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). -@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") -object TestHive extends TestHiveContext( - new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) From 9ea7f7c946360e89b1b6865418429306317e4783 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 17:41:54 -0700 Subject: [PATCH 31/39] Fix style --- .../spark/sql/hive/test/TestHiveContext.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 2c8114fecb2a..2007b8c4f98e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -140,12 +140,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => hiveFilesTemp.mkdir() ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) - private val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { - new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) - } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + - File.separator + "resources") - } + private val inRepoTests = + if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) + } else { + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + File.separator + "resources") + } def getHiveFile(path: String): File = { val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) From ddc4b057e664ba7e4870f96268bd85382fe45a6f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 18:46:04 -0700 Subject: [PATCH 32/39] Revert "Use consistent name for added back singletons" This reverts commit 451fa37bb8324a5b93d8bbc31038f26a4c12b2f3. 
--- .../spark/sql/test/TestSQLContext.scala | 16 ++++----- .../spark/sql/hive/test/TestHiveContext.scala | 31 ++++++----------- .../apache/spark/sql/hive/test/TestHive.scala | 34 +++++++++++++++++++ 3 files changed, 52 insertions(+), 29 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 6e4c16f16f5f..d34de0c62211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.{SQLConf, SQLContext} private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => def this() { - this(TestSQLContext.defaultSparkContext()) + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) } // Use fewer partitions to speed up testing @@ -50,14 +51,11 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } } -private[sql] object TestSQLContext { - def defaultSparkContext(): SparkContext = { - new SparkContext("local[2]", "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true")) - } -} - // Note: this should NOT be used for internal Spark unit tests because the singleton makes it // very difficult to start a SQLContext with a custom underlying SparkContext (SPARK-9580). @deprecated("instantiate new TestSQLContext instead of using this singleton", "1.5.0") -object SingletonTestSQLContext extends TestSQLContext(TestSQLContext.defaultSparkContext()) +object TestSQLContext extends SQLContext( + new SparkContext( + "local[2]", + "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 2007b8c4f98e..f0bffcbc0fff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -54,7 +54,17 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => import TestHiveContext._ def this() { - this(defaultSparkContext()) + this(new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + // SPARK-3729: Test key required to check for initialization errors with config. + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests @@ -447,23 +457,4 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => private[hive] object TestHiveContext { case class TestTable(name: String, commands: (() => Unit)*) - - def defaultSparkContext(): SparkContext = { - new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - // SPARK-3729: Test key required to check for initialization errors with config. 
- .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") - // SPARK-8910 - .set("spark.ui.enabled", "false")) - } } - -// Note: this should NOT be used for internal Spark unit tests because the singleton makes it -// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). -@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") -object GlobalTestHiveContext extends TestHiveContext(TestHiveContext.defaultSparkContext()) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala new file mode 100644 index 000000000000..cdb9eb270af7 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.apache.spark.{SparkConf, SparkContext} + +// Note: this should NOT be used for internal Spark unit tests because the singleton makes it +// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). +@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") +object TestHive extends TestHiveContext( + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) From ece3a81cf82f71f64910b30dcb7b9a2dff395703 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 18:48:08 -0700 Subject: [PATCH 33/39] Remove TestSQLContext, but keep TestHive per @marmbrus' request. 
--- .../spark/sql/test/TestSQLContext.scala | 9 ------- .../spark/sql/hive/test/TestHiveContext.scala | 24 ++++++++++--------- .../apache/spark/sql/hive/test/TestHive.scala | 14 ++--------- 3 files changed, 15 insertions(+), 32 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index d34de0c62211..92ef2f7d74ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -50,12 +50,3 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel protected override def _sqlContext: SQLContext = self } } - -// Note: this should NOT be used for internal Spark unit tests because the singleton makes it -// very difficult to start a SQLContext with a custom underlying SparkContext (SPARK-9580). -@deprecated("instantiate new TestSQLContext instead of using this singleton", "1.5.0") -object TestSQLContext extends SQLContext( - new SparkContext( - "local[2]", - "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index f0bffcbc0fff..1900967bd991 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -54,17 +54,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => import TestHiveContext._ def this() { - this(new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - // SPARK-3729: Test key required to check for initialization errors with config. - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) + this(TestHiveContext.defaultSparkContext()) } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests @@ -457,4 +447,16 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => private[hive] object TestHiveContext { case class TestTable(name: String, commands: (() => Unit)*) + + def defaultSparkContext(): SparkContext = { + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false")) + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala index cdb9eb270af7..0556fbb3fa4c 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala @@ -17,18 +17,8 @@ package org.apache.spark.sql.hive.test -import org.apache.spark.{SparkConf, SparkContext} // Note: this should NOT be used for internal Spark unit tests because the singleton makes it // very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). 
-@deprecated("instantiate new TestHiveContext instead of using this singleton", "1.5.0") -object TestHive extends TestHiveContext( - new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) +@deprecated("instantiate new TestHiveContext instead", "1.5.0") +object TestHive extends TestHiveContext(TestHiveContext.defaultSparkContext()) From 8d69bf859d01d9c85494f0c5f370e109b2a64b37 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 19:29:58 -0700 Subject: [PATCH 34/39] Fix test --- .../spark/sql/sources/JsonHadoopFsRelationSuite.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index ed6d512ab36f..0a823bcaa178 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - import sqlContext._ - test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -36,7 +34,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext + ctx.sparkContext .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") .saveAsTextFile(partitionDir.toString) } @@ -45,7 +43,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + ctx.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -63,14 +61,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { val data = Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil - val df = createDataFrame(sparkContext.parallelize(data), schema) + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. checkAnswer( - read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + ctx.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } From b58ae73343e823819ad7d860d34f8056881cd3fb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 22:22:51 -0700 Subject: [PATCH 35/39] Fix tests We were getting NPEs in HiveCompatibilitySuite and subclasses because we referenced a SQLContext in the constructor. This is because we were using a helper method that was part of the TestHiveContext, which was not initialized yet. This commit moves the helper method out to a static object. The method does not inherently have to be an instance method of the TestHiveContext... 
--- .../HashJoinCompatibilitySuite.scala | 7 ++- .../execution/HiveCompatibilitySuite.scala | 7 ++- .../HiveWindowFunctionQuerySuite.scala | 5 +- .../spark/sql/hive/test/TestHiveContext.scala | 54 +++++++++---------- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 9 ++-- .../sql/hive/execution/HiveSerDeSuite.scala | 6 ++- .../sql/hive/execution/SQLQuerySuite.scala | 7 +-- 8 files changed, 55 insertions(+), 44 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 5fefce41a2ba..58d8cb042cd6 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -31,8 +31,11 @@ class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { } override def afterAll() { - ctx.setConf(SQLConf.SORTMERGE_JOIN, true) - super.afterAll() + try { + ctx.setConf(SQLConf.SORTMERGE_JOIN, true) + } finally { + super.afterAll() + } } override def whiteList = Seq( diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index abec71c29d15..a726de34cd91 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.hive.test.TestHiveContext /** * Runs the test cases that are included in the hive distribution. @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... 
get from classpath - private lazy val hiveQueryDir = ctx.getHiveFile( + private lazy val hiveQueryDir = TestHiveContext.getHiveFile( "ql/src/test/queries/clientpositive".split("/").mkString(File.separator)) private val originalTimeZone = TimeZone.getDefault @@ -39,7 +40,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + override def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { super.beforeAll() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 7bd13f437227..aed6a72423e7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -22,6 +22,7 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.util.Utils /** @@ -58,7 +59,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte | p_retailprice DOUBLE, | p_comment STRING) """.stripMargin) - val testData1 = ctx.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + val testData1 = TestHiveContext.getHiveFile("data/files/part_tiny.txt").getCanonicalPath ctx.sql( s""" |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part @@ -82,7 +83,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |row format delimited |fields terminated by '|' """.stripMargin) - val testData2 = ctx.getHiveFile("data/files/over1k").getCanonicalPath + val testData2 = TestHiveContext.getHiveFile("data/files/over1k").getCanonicalPath ctx.sql( s""" |LOAD DATA LOCAL INPATH '$testData2' overwrite into table over1k diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala index 1900967bd991..1c31bb10831a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala @@ -89,9 +89,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) - /** The location of the hive source code. */ - private lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") - // Override so we can intercept relative paths and rewrite them to point at hive. 
override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(substitutor.substitute(this.hiveconf, sql))) @@ -114,14 +111,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => } } - /** - * Returns the value of specified environmental variable as a [[java.io.File]] after checking - * to ensure it exists - */ - private def envVarToFile(envVar: String): Option[File] = { - Option(System.getenv(envVar)).map(new File(_)) - } - /** * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the * hive test cases assume the system is set up. @@ -140,22 +129,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => hiveFilesTemp.mkdir() ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) - private val inRepoTests = - if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { - new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) - } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + - File.separator + "resources") - } - - def getHiveFile(path: String): File = { - val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) - hiveDevHome - .map(new File(_, stripped)) - .filter(_.exists) - .getOrElse(new File(inRepoTests, stripped)) - } - private val describedTable = "DESCRIBE (\\w+)".r /** @@ -459,4 +432,31 @@ private[hive] object TestHiveContext { // SPARK-8910 .set("spark.ui.enabled", "false")) } + + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + + /** The location of the hive source code. 
*/ + private lazy val hiveDevHome: Option[File] = envVarToFile("HIVE_DEV_HOME") + + private lazy val inRepoTests: File = + if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) + } else { + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + File.separator + "resources") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index aa1705019933..3003c3c19f96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -47,8 +47,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = ctx.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = ctx.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = TestHiveContext.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 212abc780f03..0ccc95227ab6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -876,7 +877,7 @@ class HiveQuerySuite extends HiveComparisonTest { } test("ADD JAR command") { - val testJar = ctx.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + val testJar = TestHiveContext.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath ctx.sql("CREATE TABLE alter1(a INT, b INT)") intercept[Exception] { ctx.sql( @@ -889,8 +890,8 @@ class HiveQuerySuite extends HiveComparisonTest { test("ADD JAR command 2") { // this is a test case from mapjoin_addjar.q - val testJar = ctx.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath - val testData = ctx.getHiveFile("data/files/sample.json").getCanonicalPath + val testJar = TestHiveContext.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHiveContext.getHiveFile("data/files/sample.json").getCanonicalPath ctx.sql(s"ADD JAR $testJar") ctx.sql( """CREATE TABLE t1(a string, b string) @@ -901,7 +902,7 @@ class HiveQuerySuite extends HiveComparisonTest { } test("ADD FILE command") { - val testFile = ctx.getHiveFile("data/files/v1.txt").getCanonicalFile + val testFile = TestHiveContext.getHiveFile("data/files/v1.txt").getCanonicalFile ctx.sql(s"ADD FILE $testFile") val checkAddFileRDD = ctx.sparkContext.parallelize(1 to 2, 
1).mapPartitions { _ => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index ece3715e19f6..2d784f8062a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.hive.test.TestHiveContext + /** * A set of tests that validates support for Hive SerDe. */ @@ -30,8 +32,8 @@ class HiveSerDeSuite extends HiveComparisonTest { |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") """.stripMargin) - ctx.sql( - s"LOAD DATA LOCAL INPATH '${ctx.getHiveFile("data/files/sales.txt")}' INTO TABLE sales") + val dataFile = TestHiveContext.getHiveFile("data/files/sales.txt") + ctx.sql(s"LOAD DATA LOCAL INPATH '$dataFile' INTO TABLE sales") } // table sales is not a cache table, and will be clear after reset diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4fd877970b05..75eb0c4c8016 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.{SharedHiveContext, TestHiveContext} import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestData.TestData @@ -67,7 +67,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { import testImplicits._ test("UDTF") { - ctx.sql(s"ADD JAR ${ctx.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + ctx.sql(s"ADD JAR ${TestHiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF ctx.sql( @@ -1044,7 +1044,8 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. 
- ctx.sql(s"ADD JAR ${ctx.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + val jar = TestHiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + ctx.sql(s"ADD JAR $jar") try { ctx.sql( """ From 48af8e4cf24170afc537eb4b430c6c8a0feb3d5c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 09:33:11 -0700 Subject: [PATCH 36/39] Fix another before / after alls --- .../hive/thriftserver/UISeleniumSuite.scala | 2 -- .../HiveWindowFunctionQuerySuite.scala | 26 +++++++++++++------ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de45..bf431cd6b026 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index aed6a72423e7..866fc07170ca 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -38,6 +38,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte private val testTempDir = Utils.createTempDir() override def beforeAll() { + super.beforeAll() ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -100,10 +101,14 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte } override def afterAll() { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.reset() + try { + ctx.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + ctx.reset() + } finally { + super.afterAll() + } } ///////////////////////////////////////////////////////////////////////////// @@ -766,6 +771,7 @@ class HiveWindowFunctionQueryFileSuite private val testTempDir = Utils.createTempDir() override def beforeAll() { + super.beforeAll() ctx.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -783,10 +789,14 @@ class HiveWindowFunctionQueryFileSuite } override def afterAll() { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.reset() + try { + ctx.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + ctx.reset() 
+ } finally { + super.afterAll() + } } override def blackList: Seq[String] = Seq( From f599bbc19b7c80d651156082821082b46700430a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 12:06:14 -0700 Subject: [PATCH 37/39] Revert all Hive related changes This is because setting up multiple HiveContext's makes tests really unstable. We keep running out of PermGen space and run into JVM seg faults once in a while. This reduces the size of the patch significantly and only deals with the singleton SQLContext. --- project/SparkBuild.scala | 8 +- sql/README.md | 5 +- .../HashJoinCompatibilitySuite.scala | 10 +- .../execution/HiveCompatibilitySuite.scala | 40 +- .../HiveWindowFunctionQuerySuite.scala | 62 +-- .../{TestHiveContext.scala => TestHive.scala} | 113 ++-- .../sql/hive/test/SharedHiveContext.scala | 69 --- .../apache/spark/sql/hive/test/TestHive.scala | 24 - .../spark/sql/hive/JavaDataFrameSuite.java | 5 +- .../hive/JavaMetastoreDataSourcesSuite.java | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 145 ++--- .../spark/sql/hive/ErrorPositionSuite.scala | 11 +- .../hive/HiveDataFrameAnalyticsSuite.scala | 40 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 5 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 24 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 41 +- .../spark/sql/hive/HiveParquetSuite.scala | 37 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 15 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 143 ++--- .../spark/sql/hive/ListTablesSuite.scala | 46 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 253 ++++----- .../spark/sql/hive/MultiDatabaseSuite.scala | 83 +-- .../hive/ParquetHiveCompatibilitySuite.scala | 52 +- .../spark/sql/hive/QueryPartitionSuite.scala | 29 +- .../spark/sql/hive/SerializationSuite.scala | 10 +- .../spark/sql/hive/StatisticsSuite.scala | 54 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 8 +- .../execution/AggregationQuerySuite.scala | 136 +++-- .../execution/BigDataBenchmarkSuite.scala | 42 +- .../hive/execution/ConcurrentHiveSuite.scala | 20 +- .../hive/execution/HiveComparisonTest.scala | 36 +- .../sql/hive/execution/HiveExplainSuite.scala | 30 +- .../HiveOperatorQueryableSuite.scala | 15 +- .../sql/hive/execution/HivePlanTest.scala | 11 +- .../sql/hive/execution/HiveQuerySuite.scala | 308 ++++++----- .../hive/execution/HiveResolutionSuite.scala | 23 +- .../sql/hive/execution/HiveSerDeSuite.scala | 19 +- .../hive/execution/HiveTableScanSuite.scala | 34 +- .../execution/HiveTypeCoercionSuite.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 162 +++--- .../sql/hive/execution/PruningSuite.scala | 24 +- .../sql/hive/execution/SQLQuerySuite.scala | 506 +++++++++--------- .../execution/ScriptTransformationSuite.scala | 15 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 9 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 60 ++- .../spark/sql/hive/orc/OrcQuerySuite.scala | 70 +-- .../spark/sql/hive/orc/OrcSourceSuite.scala | 52 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 15 +- .../apache/spark/sql/hive/parquetSuites.scala | 275 +++++----- .../CommitFailureTestRelationSuite.scala | 10 +- .../sources/JsonHadoopFsRelationSuite.scala | 10 +- .../ParquetHadoopFsRelationSuite.scala | 19 +- .../SimpleTextHadoopFsRelationSuite.scala | 7 +- .../sql/sources/hadoopFsRelationSuites.scala | 58 +- 54 files changed, 1601 insertions(+), 1704 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/test/{TestHiveContext.scala => TestHive.scala} (93%) delete mode 100644 
sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 38872e296692..109722210884 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -365,14 +365,10 @@ object Hive { |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ + |import org.apache.spark.sql.hive.test.TestHive._ |import org.apache.spark.sql.types._ - | - |val sc = new SparkContext("local[*]", "dev-shell") - |val hc = new HiveContext(sc) - |import hc.implicits._ - |import hc._ """.stripMargin, - cleanupCommands in console := "sc.stop()", + cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. diff --git a/sql/README.md b/sql/README.md index 266cb92f1b7e..63d4dac9829e 100644 --- a/sql/README.md +++ b/sql/README.md @@ -60,11 +60,8 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ -sc: org.apache.spark.SparkContext = org.apache.spark.SparkContext@27fc0441 -hc: org.apache.spark.sql.hive.HiveContext = org.apache.spark.sql.hive.HiveContext@127b5be9 -import hc.implicits._ -import hc._ Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 58d8cb042cd6..1a5ba20404c4 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution with hash joins. 
@@ -27,15 +28,12 @@ import org.apache.spark.sql.SQLConf class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - ctx.setConf(SQLConf.SORTMERGE_JOIN, false) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) } override def afterAll() { - try { - ctx.setConf(SQLConf.SORTMERGE_JOIN, true) - } finally { - super.afterAll() - } + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + super.afterAll() } override def whiteList = Seq( diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index a726de34cd91..ab309e0a1d36 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,57 +20,49 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { - // TODO: bundle in jar files... get from classpath - private lazy val hiveQueryDir = TestHiveContext.getHiveFile( + private lazy val hiveQueryDir = TestHive.getHiveFile( "ql/src/test/queries/clientpositive".split("/").mkString(File.separator)) private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private val originalColumnBatchSize = TestHive.conf.columnBatchSize + private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - override def testCases: Seq[(String, File)] = { - hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) - } + def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { - super.beforeAll() - ctx.cacheTables = true + TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) RuleExecutor.resetTime() } override def afterAll() { - try { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + 
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - // For debugging dump some statistics about how much time was spent in various optimizer rules. - logWarning(RuleExecutor.dumpTimeSpent()) - } finally { - super.afterAll() - } + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 866fc07170ca..92bb9e6d73af 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -22,7 +22,8 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.util.Utils /** @@ -32,22 +33,20 @@ import org.apache.spark.util.Utils * files, every `createQueryTest` calls should explicitly set `reset` to `false`. */ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { - private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() override def beforeAll() { - super.beforeAll() - ctx.cacheTables = true + TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) // Create the table used in windowing.q - ctx.sql("DROP TABLE IF EXISTS part") - ctx.sql( + sql("DROP TABLE IF EXISTS part") + sql( """ |CREATE TABLE part( | p_partkey INT, @@ -60,14 +59,14 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte | p_retailprice DOUBLE, | p_comment STRING) """.stripMargin) - val testData1 = TestHiveContext.getHiveFile("data/files/part_tiny.txt").getCanonicalPath - ctx.sql( + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( s""" |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part """.stripMargin) - ctx.sql("DROP TABLE IF EXISTS over1k") - ctx.sql( + sql("DROP TABLE IF EXISTS over1k") + sql( """ |create table over1k( | t tinyint, @@ -84,8 +83,8 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |row format delimited |fields terminated by '|' """.stripMargin) - val testData2 = TestHiveContext.getHiveFile("data/files/over1k").getCanonicalPath - ctx.sql( + val testData2 = TestHive.getHiveFile("data/files/over1k").getCanonicalPath + sql( s""" |LOAD DATA LOCAL INPATH '$testData2' overwrite into table over1k """.stripMargin) @@ -93,22 +92,18 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. // This is used to generate golden files. 
- ctx.sql("set hive.plan.serialization.format=kryo") + sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - ctx.sql(s"set fs.default.name=file://$testTempDir/") + sql(s"set fs.default.name=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. - ctx.sql("set mapred.job.tracker=local") + sql("set mapred.job.tracker=local") } override def afterAll() { - try { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.reset() - } finally { - super.afterAll() - } + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() } ///////////////////////////////////////////////////////////////////////////// @@ -771,8 +766,7 @@ class HiveWindowFunctionQueryFileSuite private val testTempDir = Utils.createTempDir() override def beforeAll() { - super.beforeAll() - ctx.cacheTables = true + TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -781,22 +775,18 @@ class HiveWindowFunctionQueryFileSuite // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. // This is used to generate golden files. - // ctx.sql("set hive.plan.serialization.format=kryo") + // sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - // ctx.sql(s"set fs.default.name=file://$testTempDir/") + // sql(s"set fs.default.name=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. - // ctx.sql("set mapred.job.tracker=local") + // sql("set mapred.job.tracker=local") } override def afterAll() { - try { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.reset() - } finally { - super.afterAll() - } + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() } override def blackList: Seq[String] = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala similarity index 93% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 1c31bb10831a..4eae699ac3b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -38,24 +37,39 @@ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ import scala.collection.JavaConversions._ +// SPARK-3729: Test key required to check for initialization errors with 
config. +object TestHive + extends TestHiveContext( + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) + /** * A locally running test instance of Spark's Hive execution engine. * * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. * Calling [[reset]] will delete all tables and other state in the database, leaving the database * in a "clean" state. + * + * TestHive is singleton object version of this class because instantiating multiple copies of the + * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of + * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => - import HiveContext._ - import TestHiveContext._ +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { + self => - def this() { - this(TestHiveContext.defaultSparkContext()) - } + import HiveContext._ // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. @@ -66,7 +80,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") - private lazy val scratchDirPath = { + lazy val scratchDirPath = { val dir = Utils.createTempDir(namePrefix = "scratch-") dir.delete() dir @@ -84,11 +98,16 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => ) } - private val testTempDir = Utils.createTempDir() + val testTempDir = Utils.createTempDir() // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + /** The location of the compiled hive distribution */ + lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ + lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + // Override so we can intercept relative paths and rewrite them to point at hive. override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(substitutor.substitute(this.hiveconf, sql))) @@ -111,6 +130,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => } } + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + /** * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the * hive test cases assume the system is set up. 
@@ -124,12 +151,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => cmd } - private val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) - private val describedTable = "DESCRIBE (\\w+)".r + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) + } else { + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + File.separator + "resources") + } + + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + val describedTable = "DESCRIBE (\\w+)".r /** * Override QueryExecution with special debug workflow. @@ -156,6 +198,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => } } + case class TestTable(name: String, commands: (() => Unit)*) + protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { () => new QueryExecution(sql).stringResult(): Unit @@ -167,7 +211,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => * demand when a query are run against it. */ @transient - private lazy val testTables = new mutable.HashMap[String, TestTable]() + lazy val testTables = new mutable.HashMap[String, TestTable]() def registerTestTable(testTable: TestTable): Unit = { testTables += (testTable.name -> testTable) @@ -177,7 +221,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql @transient - private val hiveQTestUtilTables = Seq( + val hiveQTestUtilTables = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), @@ -417,46 +461,3 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => } } } - -private[hive] object TestHiveContext { - case class TestTable(name: String, commands: (() => Unit)*) - - def defaultSparkContext(): SparkContext = { - new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), - "TestSQLContext", - new SparkConf() - .set("spark.sql.test", "") - .set("spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe") - // SPARK-8910 - .set("spark.ui.enabled", "false")) - } - - def getHiveFile(path: String): File = { - val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) - hiveDevHome - .map(new File(_, stripped)) - .filter(_.exists) - .getOrElse(new File(inRepoTests, stripped)) - } - - /** - * Returns the value of specified environmental variable as a [[java.io.File]] after checking - * to ensure it exists - */ - private def envVarToFile(envVar: String): Option[File] = { - Option(System.getenv(envVar)).map(new File(_)) - } - - /** The location of the hive source code. 
*/ - private lazy val hiveDevHome: Option[File] = envVarToFile("HIVE_DEV_HOME") - - private lazy val inRepoTests: File = - if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { - new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) - } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + - File.separator + "resources") - } -} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala deleted file mode 100644 index b18fca6bcd92..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/SharedHiveContext.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import org.apache.spark.sql.test.SQLTestUtils - - -/** - * Helper trait for hive test suites where all tests share a single [[TestHiveContext]]. - * This is analogous to [[org.apache.spark.sql.test.SharedSQLContext]]. - */ -private[spark] trait SharedHiveContext extends SQLTestUtils { - - /** - * The [[TestHiveContext]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _ctx: TestHiveContext = null - - /** - * The [[TestHiveContext]] to use for all tests in this suite. - */ - protected def ctx: TestHiveContext = _ctx - protected def hiveContext: TestHiveContext = _ctx - protected override def _sqlContext: TestHiveContext = _ctx - - /** - * Initialize the [[TestHiveContext]]. - */ - protected override def beforeAll(): Unit = { - if (_ctx == null) { - _ctx = new TestHiveContext - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - try { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null - } - } finally { - super.afterAll() - } - } - -} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala deleted file mode 100644 index 0556fbb3fa4c..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHive.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - - -// Note: this should NOT be used for internal Spark unit tests because the singleton makes it -// very difficult to start a HiveContext with a custom underlying SparkContext (SPARK-9580). -@deprecated("instantiate new TestHiveContext instead", "1.5.0") -object TestHive extends TestHiveContext(TestHiveContext.defaultSparkContext()) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index d93add697fd7..21b053f07a3b 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -32,7 +32,8 @@ import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.HiveContext; -import org.apache.spark.sql.hive.test.TestHiveContext; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { @@ -50,7 +51,7 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - hc = new TestHiveContext(); + hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); List jsonObjects = new ArrayList(10); diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 30a35a6c06f1..15c2c3deb0d8 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -38,7 +38,7 @@ import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; import org.apache.spark.sql.hive.HiveContext; -import org.apache.spark.sql.hive.test.TestHiveContext; +import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -63,7 +63,7 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = new TestHiveContext(); + sqlContext = TestHive$.MODULE$; sc = new JavaSparkContext(sqlContext.sparkContext()); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f8b236af7dee..39d315aaeab5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala 
@@ -19,16 +19,17 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest with SharedHiveContext { +class CachedTableSuite extends QueryTest { def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -38,168 +39,168 @@ class CachedTableSuite extends QueryTest with SharedHiveContext { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache table") { - val preCacheResults = ctx.sql("SELECT * FROM src").collect().toSeq + val preCacheResults = sql("SELECT * FROM src").collect().toSeq - ctx.cacheTable("src") - assertCached(ctx.sql("SELECT * FROM src")) + cacheTable("src") + assertCached(sql("SELECT * FROM src")) checkAnswer( - ctx.sql("SELECT * FROM src"), + sql("SELECT * FROM src"), preCacheResults) - assertCached(ctx.sql("SELECT * FROM src s")) + assertCached(sql("SELECT * FROM src s")) checkAnswer( - ctx.sql("SELECT * FROM src s"), + sql("SELECT * FROM src s"), preCacheResults) - ctx.uncacheTable("src") - assertCached(ctx.sql("SELECT * FROM src"), 0) + uncacheTable("src") + assertCached(sql("SELECT * FROM src"), 0) } test("cache invalidation") { - ctx.sql("CREATE TABLE cachedTable(key INT, value STRING)") + sql("CREATE TABLE cachedTable(key INT, value STRING)") - ctx.sql("INSERT INTO TABLE cachedTable SELECT * FROM src") - checkAnswer(ctx.sql("SELECT * FROM cachedTable"), ctx.table("src").collect().toSeq) + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) - ctx.cacheTable("cachedTable") - checkAnswer(ctx.sql("SELECT * FROM cachedTable"), ctx.table("src").collect().toSeq) + cacheTable("cachedTable") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) - ctx.sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") checkAnswer( - ctx.sql("SELECT * FROM cachedTable"), - ctx.table("src").collect().toSeq ++ ctx.table("src").collect().toSeq) + sql("SELECT * FROM cachedTable"), + table("src").collect().toSeq ++ table("src").collect().toSeq) - ctx.sql("DROP TABLE cachedTable") + sql("DROP TABLE cachedTable") } test("Drop cached table") { - ctx.sql("CREATE TABLE cachedTableTest(a INT)") - ctx.cacheTable("cachedTableTest") - ctx.sql("SELECT * FROM cachedTableTest").collect() - ctx.sql("DROP TABLE cachedTableTest") + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") intercept[AnalysisException] { - ctx.sql("SELECT * FROM cachedTableTest").collect() + sql("SELECT * FROM cachedTableTest").collect() } } 
test("DROP nonexistant table") { - ctx.sql("DROP TABLE IF EXISTS nonexistantTable") + sql("DROP TABLE IF EXISTS nonexistantTable") } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("src") + TestHive.uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - ctx.sql("CACHE TABLE src") - assertCached(ctx.table("src")) - assert(ctx.isCached("src"), "Table 'src' should be cached") + TestHive.sql("CACHE TABLE src") + assertCached(table("src")) + assert(TestHive.isCached("src"), "Table 'src' should be cached") - ctx.sql("UNCACHE TABLE src") - assertCached(ctx.table("src"), 0) - assert(!ctx.isCached("src"), "Table 'src' should not be cached") + TestHive.sql("UNCACHE TABLE src") + assertCached(table("src"), 0) + assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - ctx.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assertCached(ctx.table("testCacheTable")) + sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assertCached(table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + uncacheTable("testCacheTable") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE TABLE tableName AS SELECT ...") { - ctx.sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") - assertCached(ctx.table("testCacheTable")) + sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") + assertCached(table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + uncacheTable("testCacheTable") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE LAZY TABLE tableName") { - ctx.sql("CACHE LAZY TABLE src") - assertCached(ctx.table("src")) + sql("CACHE LAZY TABLE src") + assertCached(table("src")) val rddId = rddIdOf("src") assert( !isMaterialized(rddId), "Lazily cached in-memory table shouldn't be materialized eagerly") - ctx.sql("SELECT COUNT(*) FROM src").collect() + sql("SELECT COUNT(*) FROM src").collect() assert( isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("src") + uncacheTable("src") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } test("CACHE TABLE with Hive UDF") { - ctx.sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") - assertCached(ctx.table("udfTest")) - ctx.uncacheTable("udfTest") + sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") + assertCached(table("udfTest")) + uncacheTable("udfTest") } test("REFRESH TABLE also needs to recache the data (data source tables)") { val tempPath: File = Utils.createTempDir() tempPath.delete() - ctx.table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) - ctx.sql("DROP TABLE IF EXISTS refreshTable") - ctx.createExternalTable("refreshTable", tempPath.toString, "parquet") + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) + sql("DROP TABLE IF EXISTS refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") checkAnswer( - ctx.table("refreshTable"), - ctx.table("src").collect()) + table("refreshTable"), + 
table("src").collect()) // Cache the table. - ctx.sql("CACHE TABLE refreshTable") - assertCached(ctx.table("refreshTable")) + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) // Append new data. - ctx.table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) // We are still using the old data. - assertCached(ctx.table("refreshTable")) + assertCached(table("refreshTable")) checkAnswer( - ctx.table("refreshTable"), - ctx.table("src").collect()) + table("refreshTable"), + table("src").collect()) // Refresh the table. - ctx.sql("REFRESH TABLE refreshTable") + sql("REFRESH TABLE refreshTable") // We are using the new data. - assertCached(ctx.table("refreshTable")) + assertCached(table("refreshTable")) checkAnswer( - ctx.table("refreshTable"), - ctx.table("src").unionAll(ctx.table("src")).collect()) + table("refreshTable"), + table("src").unionAll(table("src")).collect()) // Drop the table and create it again. - ctx.sql("DROP TABLE refreshTable") - ctx.createExternalTable("refreshTable", tempPath.toString, "parquet") + sql("DROP TABLE refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") // It is not cached. - assert(!ctx.isCached("refreshTable"), "refreshTable should not be cached.") + assert(!isCached("refreshTable"), "refreshTable should not be cached.") // Refresh the table. REFRESH TABLE command should not make a uncached // table cached. - ctx.sql("REFRESH TABLE refreshTable") + sql("REFRESH TABLE refreshTable") checkAnswer( - ctx.table("refreshTable"), - ctx.table("src").unionAll(ctx.table("src")).collect()) + table("refreshTable"), + table("src").unionAll(table("src")).collect()) // It is not cached. - assert(!ctx.isCached("refreshTable"), "refreshTable should not be cached.") + assert(!isCached("refreshTable"), "refreshTable should not be cached.") - ctx.sql("DROP TABLE refreshTable") + sql("DROP TABLE refreshTable") Utils.deleteRecursively(tempPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index fdb9725b2578..30f5313d2b81 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,15 +22,14 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { - import testImplicits._ +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { - override def beforeAll(): Unit = { - super.beforeAll() + before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") } @@ -123,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter with SharedHiveCo test(name) { val error = intercept[AnalysisException] { - quietly(ctx.sql(query)) + quietly(sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 33dcd8a484d6..fb10f8583da9 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,51 +19,47 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. -class HiveDataFrameAnalyticsSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { + private var testData: DataFrame = _ - private var _testData: DataFrame = _ - - override def beforeAll(): Unit = { - super.beforeAll() - _testData = Seq((1, 2), (2, 4)).toDF("a", "b") - ctx.registerDataFrameAsTable(_testData, "mytable") + override def beforeAll() { + testData = Seq((1, 2), (2, 4)).toDF("a", "b") + TestHive.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - try { - ctx.dropTempTable("mytable") - } finally { - super.afterAll() - } + TestHive.dropTempTable("mytable") } test("rollup") { checkAnswer( - _testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), - ctx.sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() ) checkAnswer( - _testData.rollup("a", "b").agg(sum("b")), - ctx.sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + testData.rollup("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() ) } test("cube") { checkAnswer( - _testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), - ctx.sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() ) checkAnswer( - _testData.cube("a", "b").agg(sum("b")), - ctx.sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + testData.cube("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 6d0ef530ca65..52e782768cb7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive.implicits._ -class HiveDataFrameJoinSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class HiveDataFrameJoinSuite extends QueryTest { // We should move this into SQL package if we make case sensitivity configurable in SQL. 
test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 4def9557e2ae..c177cbdd991c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ -class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class HiveDataFrameWindowSuite extends QueryTest { test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -54,7 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - ctx.sql( + sql( """SELECT | lead(value) OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -67,7 +67,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - ctx.sql( + sql( """SELECT | lag(value) OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -80,7 +80,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - ctx.sql( + sql( """SELECT | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -93,7 +93,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - ctx.sql( + sql( """SELECT | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) | FROM window_table""".stripMargin).collect()) @@ -116,7 +116,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { rank().over(Window.partitionBy("value").orderBy("key")), cumeDist().over(Window.partitionBy("value").orderBy("key")), percentRank().over(Window.partitionBy("value").orderBy("key"))), - ctx.sql( + sql( s"""SELECT |key, |max(key) over (partition by value order by key), @@ -139,7 +139,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - ctx.sql( + sql( """SELECT | avg(key) OVER | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) @@ -152,7 +152,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - ctx.sql( + sql( """SELECT | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) @@ -170,7 +170,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { last("value").over( 
Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - ctx.sql( + sql( """SELECT | key, | last_value(value) OVER @@ -199,7 +199,7 @@ class HiveDataFrameWindowSuite extends QueryTest with SharedHiveContext { avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - ctx.sql( + sql( """SELECT | key, | last_value(value) OVER diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d1800a8f9148..59e65ff97b8e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest -import org.apache.spark.sql.test.ExamplePointUDT +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.{Logging, SparkFunSuite} -class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext with Logging { - import testImplicits._ +class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { test("struct field should accept underscore in sub-column name") { val hiveTypeStr = "struct" @@ -45,16 +46,16 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with SharedHiveContext wit } test("duplicated metastore relations") { - val df = ctx.sql("SELECT * FROM src") + val df = sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } -class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SharedHiveContext { - import testImplicits._ +class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { + override val sqlContext = TestHive - private lazy val testDF = ctx.range(1, 3).select( + private val testDF = range(1, 3).select( ('id + 0.1) cast DecimalType(10, 3) as 'd1, 'id cast StringType as 'd2 ).coalesce(1) @@ -80,7 +81,7 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared .format(provider) .saveAsTable("t") - val hiveTable = ctx.catalog.client.getTable("default", "t") + val hiveTable = catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) assert(hiveTable.serde === Some(serde)) @@ -92,8 +93,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) - checkAnswer(ctx.table("t"), testDF) - assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -109,7 +110,7 @@ class 
DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared .option("path", path.toString) .saveAsTable("t") - val hiveTable = ctx.catalog.client.getTable("default", "t") + val hiveTable = catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) assert(hiveTable.serde === Some(serde)) @@ -121,8 +122,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) - checkAnswer(ctx.table("t"), testDF) - assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } @@ -132,13 +133,13 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared withTable("t") { val path = dir.getCanonicalPath - ctx.sql( + sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = ctx.catalog.client.getTable("default", "t") + val hiveTable = catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) assert(hiveTable.outputFormat === Some(outputFormat)) assert(hiveTable.serde === Some(serde)) @@ -151,8 +152,8 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with Shared assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.hiveType) === Seq("int", "string")) - checkAnswer(ctx.table("t"), Row(1, "val_1")) - assert(ctx.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index b90e52d373b4..1fa005d5f9a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,45 +17,48 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest with SharedHiveContext { +class HiveParquetSuite extends QueryTest with ParquetTest { + val sqlContext = TestHive + + import sqlContext._ test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(ctx.sql("SELECT upper FROM cases"), expected) - checkAnswer(ctx.sql("SELECT LOWER FROM cases"), expected) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) } } test("SELECT on Parquet table") { val data = (1 to 4).map(i => (i, s"val_$i")) withParquetTable(data, "t") { - checkAnswer(ctx.sql("SELECT * FROM t"), data.map(Row.fromTuple)) + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } } test("Simple column projection + filter on Parquet table") { withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { checkAnswer( - ctx.sql("SELECT `_1`, `_3` FROM t 
WHERE `_1` = true"), + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), Seq(Row(true, "val_2"), Row(true, "val_4"))) } } test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - ctx.sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( - ctx.sql("SELECT * FROM src ORDER BY key"), - ctx.sql("SELECT * from p ORDER BY key").collect().toSeq) + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) } } } @@ -63,14 +66,14 @@ class HiveParquetSuite extends QueryTest with ParquetTest with SharedHiveContext test("INSERT OVERWRITE TABLE Parquet table") { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - ctx.sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure - ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - ctx.sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(ctx.sql("SELECT * FROM p"), ctx.sql("SELECT * FROM t").collect().toSeq) + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d0b1df2e0921..0c2964611446 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -28,8 +28,8 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHiveContext, SharedHiveContext} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -40,17 +40,20 @@ class HiveSparkSubmitSuite extends SparkFunSuite with Matchers with ResetSystemProperties - with Timeouts - with SharedHiveContext { + with Timeouts { // TODO: rewrite these or mark them as slow tests to be run sparingly + def beforeAll() { + System.setProperty("spark.testing", "true") + } + test("SPARK-8368: includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath - val jar4 = TestHiveContext.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val 
args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), @@ -96,8 +99,6 @@ class HiveSparkSubmitSuite "--class", SPARK_9757.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", - "--conf", "spark.ui.enabled=false", - "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 167ad9b5d8bf..d33e81227db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,120 +22,125 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.hive.test.SharedHiveContext -import org.apache.spark.sql.test.SQLTestData.TestData +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -case class ThreeColumnTable(key: Int, value: String, key1: String) +/* Implicits */ +import org.apache.spark.sql.hive.test.TestHive._ + +case class TestData(key: Int, value: String) + +case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with SharedHiveContext { - import testImplicits._ +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { + import org.apache.spark.sql.hive.test.TestHive.implicits._ - private lazy val _testData = ctx.sparkContext.parallelize( + + val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - ctx.reset() + TestHive.reset() // Register the testData, which will be used in every test. - _testData.registerTempTable("testData") + testData.registerTempTable("testData") } test("insertInto() HiveTable") { - ctx.sql("CREATE TABLE createAndInsertTest (key int, value string)") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - _testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( - ctx.sql("SELECT * FROM createAndInsertTest"), - _testData.collect().toSeq + sql("SELECT * FROM createAndInsertTest"), + testData.collect().toSeq ) // Add more data. - _testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( - ctx.sql("SELECT * FROM createAndInsertTest"), - _testData.toDF().collect().toSeq ++ _testData.toDF().collect().toSeq + sql("SELECT * FROM createAndInsertTest"), + testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq ) // Now overwrite. - _testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. 
checkAnswer( - ctx.sql("SELECT * FROM createAndInsertTest"), - _testData.collect().toSeq + sql("SELECT * FROM createAndInsertTest"), + testData.collect().toSeq ) } test("Double create fails when allowExisting = false") { - ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") val message = intercept[QueryExecutionException] { - ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") }.getMessage } test("Double create does not fail when allowExisting = true") { - ctx.sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - ctx.sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = ctx.sparkContext.parallelize( + val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") - ctx.sql("CREATE TABLE hiveTableWithMapValue(m MAP )") - ctx.sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") + sql("CREATE TABLE hiveTableWithMapValue(m MAP )") + sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") checkAnswer( - ctx.sql("SELECT * FROM hiveTableWithMapValue"), + sql("SELECT * FROM hiveTableWithMapValue"), rowRDD.collect().toSeq ) - ctx.sql("DROP TABLE hiveTableWithMapValue") + sql("DROP TABLE hiveTableWithMapValue") } test("SPARK-4203:random partition directory order") { - ctx.sql("CREATE TABLE tmp_table (key int, value string)") + sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - ctx.sql( + sql( s""" |CREATE TABLE table_with_partition(c1 string) |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) |location '${tmpDir.toURI.toString}' """.stripMargin) - ctx.sql( + sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='1') |SELECT 'blarr' FROM tmp_table """.stripMargin) - ctx.sql( + sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='2') |SELECT 'blarr' FROM tmp_table """.stripMargin) - ctx.sql( + sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='3') |SELECT 'blarr' FROM tmp_table """.stripMargin) - ctx.sql( + sql( """ |INSERT OVERWRITE TABLE table_with_partition |partition (p1='a',p2='b',p3='c',p4='c',p5='4') @@ -157,104 +162,104 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter with Shared "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) - ctx.sql("DROP TABLE table_with_partition") - ctx.sql("DROP TABLE tmp_table") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE tmp_table") } test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = 
false)))) - val rowRDD = ctx.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = ctx.createDataFrame(rowRDD, schema) + val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") - ctx.sql("CREATE TABLE hiveTableWithArrayValue(a Array )") - ctx.sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") + sql("CREATE TABLE hiveTableWithArrayValue(a Array )") + sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") checkAnswer( - ctx.sql("SELECT * FROM hiveTableWithArrayValue"), + sql("SELECT * FROM hiveTableWithArrayValue"), rowRDD.collect().toSeq) - ctx.sql("DROP TABLE hiveTableWithArrayValue") + sql("DROP TABLE hiveTableWithArrayValue") } test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = ctx.sparkContext.parallelize( + val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") - ctx.sql("CREATE TABLE hiveTableWithMapValue(m Map )") - ctx.sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") + sql("CREATE TABLE hiveTableWithMapValue(m Map )") + sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") checkAnswer( - ctx.sql("SELECT * FROM hiveTableWithMapValue"), + sql("SELECT * FROM hiveTableWithMapValue"), rowRDD.collect().toSeq) - ctx.sql("DROP TABLE hiveTableWithMapValue") + sql("DROP TABLE hiveTableWithMapValue") } test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = ctx.sparkContext.parallelize( + val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") - ctx.sql("CREATE TABLE hiveTableWithStructValue(s Struct )") - ctx.sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") + sql("CREATE TABLE hiveTableWithStructValue(s Struct )") + sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") checkAnswer( - ctx.sql("SELECT * FROM hiveTableWithStructValue"), + sql("SELECT * FROM hiveTableWithStructValue"), rowRDD.collect().toSeq) - ctx.sql("DROP TABLE hiveTableWithStructValue") + sql("DROP TABLE hiveTableWithStructValue") } test("SPARK-5498:partition schema does not match table schema") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = ctx.sparkContext.parallelize( - (1 to 10).map(i => ThreeColumnTable(i, i.toString, null))).toDF() + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() - ctx.sql( + sql( s""" |CREATE TABLE table_with_partition(key int,value string) |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' """.stripMargin) - ctx.sql( + sql( """ |INSERT 
OVERWRITE TABLE table_with_partition |partition (ds='1') SELECT key,value FROM testData """.stripMargin) // test schema the same between partition and table - ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") - checkAnswer(ctx.sql("select key,value from table_with_partition where ds='1' "), + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) // test difference type of field - ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") - checkAnswer(ctx.sql("select key,value from table_with_partition where ds='1' "), + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) // add column to table - ctx.sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") - checkAnswer(ctx.sql("select key,value,key1 from table_with_partition where ds='1' "), + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), testDatawithNull.collect().toSeq ) // change column name to table - ctx.sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") - checkAnswer(ctx.sql("select keynew,value from table_with_partition where ds='1' "), + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - ctx.sql("DROP TABLE table_with_partition") + sql("DROP TABLE table_with_partition") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1a2b5eb01e09..1c15997ea8e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -17,39 +17,39 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.scalatest.BeforeAndAfterAll -class ListTablesSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row - private lazy val df = - ctx.sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { - super.beforeAll() // The catalog in HiveContext is a case insensitive one. 
- ctx.catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - ctx.catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) - ctx.sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") - ctx.sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") - ctx.sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - try { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) - ctx.sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") - ctx.sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") - ctx.sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") - } finally { - super.afterAll() - } + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") } test("get all tables of current database") { - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(tables(), sql("SHOW TABLes")).foreach { case allTables => // We are using default DB. checkAnswer( @@ -64,7 +64,7 @@ class ListTablesSuite extends QueryTest with SharedHiveContext { } test("getting all tables with a database name") { - Seq(ctx.tables("listtablessuiteDb"), ctx.sql("SHOW TABLes in listTablesSuitedb")).foreach { + Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach { case allTables => checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e1d2804d15b7..7f36a483a396 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,36 +22,37 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.InvalidInputException +import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. 
*/ -class MetastoreDataSourcesSuite - extends QueryTest - with SharedHiveContext +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll with Logging { - - import testImplicits._ + override val sqlContext = TestHive var jsonFilePath: String = _ override def beforeAll(): Unit = { - super.beforeAll() jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } test("persistent JSON table") { withTable("jsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -60,14 +61,14 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), - ctx.read.json(jsonFilePath).collect().toSeq) + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) } } test("persistent JSON table with a user specified schema") { withTable("jsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable ( |a string, |b String, @@ -80,10 +81,10 @@ class MetastoreDataSourcesSuite """.stripMargin) withTempTable("expectedJsonTable") { - ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") + read.json(jsonFilePath).registerTempTable("expectedJsonTable") checkAnswer( - ctx.sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - ctx.sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) } } } @@ -92,7 +93,7 @@ class MetastoreDataSourcesSuite withTable("jsonTable") { // This works because JSON objects are self-describing and JSONRelation can get needed // field values based on field names. - ctx.sql( + sql( s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -107,20 +108,20 @@ class MetastoreDataSourcesSuite StructField("", innerStruct, true), StructField("b", StringType, true))) - assert(expectedSchema === ctx.table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) withTempTable("expectedJsonTable") { - ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") + read.json(jsonFilePath).registerTempTable("expectedJsonTable") checkAnswer( - ctx.sql("SELECT b, ``.`=` FROM jsonTable"), - ctx.sql("SELECT b, ``.`=` FROM expectedJsonTable")) + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) } } } test("resolve shortened provider names") { withTable("jsonTable") { - ctx.sql( + sql( s""" |CREATE TABLE jsonTable |USING org.apache.spark.sql.json @@ -130,14 +131,14 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), - ctx.read.json(jsonFilePath).collect().toSeq) + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) } } test("drop table") { withTable("jsonTable") { - ctx.sql( + sql( s""" |CREATE TABLE jsonTable |USING org.apache.spark.sql.json @@ -147,13 +148,13 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), - ctx.read.json(jsonFilePath)) + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) - ctx.sql("DROP TABLE jsonTable") + sql("DROP TABLE jsonTable") intercept[Exception] { - ctx.sql("SELECT * FROM jsonTable").collect() + sql("SELECT * FROM jsonTable").collect() } assert( @@ -168,7 +169,7 @@ class MetastoreDataSourcesSuite withTable("jsonTable") { (("a", "b") :: 
Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -177,7 +178,7 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), Row("a", "b")) Utils.deleteRecursively(tempDir) @@ -186,14 +187,14 @@ class MetastoreDataSourcesSuite // Schema is cached so the new column does not show. The updated values in existing columns // will show. checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), Row("a1", "b1")) - ctx.sql("REFRESH TABLE jsonTable") + sql("REFRESH TABLE jsonTable") // Check that the refresh worked checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) } } @@ -204,7 +205,7 @@ class MetastoreDataSourcesSuite (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) withTable("jsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -213,15 +214,15 @@ class MetastoreDataSourcesSuite """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), Row("a", "b")) Utils.deleteRecursively(tempDir) (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) - ctx.sql("DROP TABLE jsonTable") + sql("DROP TABLE jsonTable") - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json |OPTIONS ( @@ -231,7 +232,7 @@ class MetastoreDataSourcesSuite // New table should reflect new schema. checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) } } @@ -239,7 +240,7 @@ class MetastoreDataSourcesSuite test("invalidate cache and reload") { withTable("jsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable (`c_!@(3)` int) |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -248,23 +249,23 @@ class MetastoreDataSourcesSuite """.stripMargin) withTempTable("expectedJsonTable") { - ctx.read.json(jsonFilePath).registerTempTable("expectedJsonTable") + read.json(jsonFilePath).registerTempTable("expectedJsonTable") checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), - ctx.sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) // Discard the cached relation. 
- ctx.invalidateTable("jsonTable") + invalidateTable("jsonTable") checkAnswer( - ctx.sql("SELECT * FROM jsonTable"), - ctx.sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - ctx.invalidateTable("jsonTable") + invalidateTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === ctx.table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) } } } @@ -272,7 +273,7 @@ class MetastoreDataSourcesSuite test("CTAS") { withTempPath { tempPath => withTable("jsonTable", "ctasJsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -280,7 +281,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - ctx.sql( + sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -289,11 +290,11 @@ class MetastoreDataSourcesSuite |SELECT * FROM jsonTable """.stripMargin) - assert(ctx.table("ctasJsonTable").schema === ctx.table("jsonTable").schema) + assert(table("ctasJsonTable").schema === table("jsonTable").schema) checkAnswer( - ctx.sql("SELECT * FROM ctasJsonTable"), - ctx.sql("SELECT * FROM jsonTable").collect()) + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) } } } @@ -303,7 +304,7 @@ class MetastoreDataSourcesSuite val tempPath = path.getCanonicalPath withTable("jsonTable", "ctasJsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -311,7 +312,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - ctx.sql( + sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -322,7 +323,7 @@ class MetastoreDataSourcesSuite // Create the table again should trigger a AnalysisException. val message = intercept[AnalysisException] { - ctx.sql( + sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -339,7 +340,7 @@ class MetastoreDataSourcesSuite // The following statement should be fine if it has IF NOT EXISTS. // It tries to create a table ctasJsonTable with a new schema. // The actual table's schema and data should not be changed. - ctx.sql( + sql( s"""CREATE TABLE IF NOT EXISTS ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -349,21 +350,21 @@ class MetastoreDataSourcesSuite """.stripMargin) // Discard the cached relation. - ctx.invalidateTable("ctasJsonTable") + invalidateTable("ctasJsonTable") // Schema should not be changed. - assert(ctx.table("ctasJsonTable").schema === ctx.table("jsonTable").schema) + assert(table("ctasJsonTable").schema === table("jsonTable").schema) // Table data should not be changed. 
checkAnswer( - ctx.sql("SELECT * FROM ctasJsonTable"), - ctx.sql("SELECT * FROM jsonTable").collect()) + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) } } } test("CTAS a managed table") { withTable("jsonTable", "ctasJsonTable", "loadedTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -371,13 +372,13 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - val expectedPath = ctx.catalog.hiveDefaultTableFilePath("ctasJsonTable") + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(ctx.sparkContext.hadoopConfiguration) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) // It is a managed table when we do not specify the location. - ctx.sql( + sql( s"""CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |AS @@ -386,7 +387,7 @@ class MetastoreDataSourcesSuite assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - ctx.sql( + sql( s"""CREATE TABLE loadedTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -394,20 +395,20 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - assert(ctx.table("ctasJsonTable").schema === ctx.table("loadedTable").schema) + assert(table("ctasJsonTable").schema === table("loadedTable").schema) checkAnswer( - ctx.sql("SELECT * FROM ctasJsonTable"), - ctx.sql("SELECT * FROM loadedTable")) + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) - ctx.sql("DROP TABLE ctasJsonTable") + sql("DROP TABLE ctasJsonTable") assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") } } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { withTable("jsonTable") { - ctx.sql( + sql( s"""CREATE TABLE jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( @@ -415,7 +416,7 @@ class MetastoreDataSourcesSuite |) """.stripMargin) - ctx.sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) + sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) } } @@ -430,21 +431,21 @@ class MetastoreDataSourcesSuite .saveAsTable("savedJsonTable") checkAnswer( - ctx.sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), (1 to 4).map(i => Row(i, s"str$i"))) checkAnswer( - ctx.sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) - ctx.invalidateTable("savedJsonTable") + invalidateTable("savedJsonTable") checkAnswer( - ctx.sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), (1 to 4).map(i => Row(i, s"str$i"))) checkAnswer( - ctx.sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) } } @@ -460,23 +461,23 @@ class MetastoreDataSourcesSuite // Save the df as a managed table (by not specifying the path). df.write.saveAsTable("savedJsonTable") - checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // We can overwrite it. 
df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") - checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // When the save mode is Ignore, we will do nothing when the table already exists. df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") // TODO in ResolvedDataSource, will convert the schema into nullable = true // hence the df.schema is not exactly the same as table("savedJsonTable").schema // assert(df.schema === table("savedJsonTable").schema) - checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // Drop table will also delete the data. - ctx.sql("DROP TABLE savedJsonTable") + sql("DROP TABLE savedJsonTable") intercept[IOException] { - ctx.read.json(ctx.catalog.hiveDefaultTableFilePath("savedJsonTable")) + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) } } @@ -488,12 +489,12 @@ class MetastoreDataSourcesSuite .option("path", tempPath.toString) .saveAsTable("savedJsonTable") - checkAnswer(ctx.sql("SELECT * FROM savedJsonTable"), df) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) } // Data should not be deleted after we drop the table. - ctx.sql("DROP TABLE savedJsonTable") - checkAnswer(ctx.read.json(tempPath.toString), df) + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) } } } @@ -501,7 +502,7 @@ class MetastoreDataSourcesSuite test("create external table") { withTempPath { tempPath => withTable("savedJsonTable", "createdJsonTable") { - val df = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map { i => + val df = read.json(sparkContext.parallelize((1 to 10).map { i => s"""{ "a": $i, "b": "str$i" }""" })) @@ -514,39 +515,39 @@ class MetastoreDataSourcesSuite } withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { - ctx.createExternalTable("createdJsonTable", tempPath.toString) - assert(ctx.table("createdJsonTable").schema === df.schema) - checkAnswer(ctx.sql("SELECT * FROM createdJsonTable"), df) + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) assert( intercept[AnalysisException] { - ctx.createExternalTable("createdJsonTable", jsonFilePath.toString) + createExternalTable("createdJsonTable", jsonFilePath.toString) }.getMessage.contains("Table createdJsonTable already exists."), "We should complain that createdJsonTable already exists") } // Data should not be deleted. - ctx.sql("DROP TABLE createdJsonTable") - checkAnswer(ctx.read.json(tempPath.toString), df) + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) // Try to specify the schema. 
withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { val schema = StructType(StructField("b", StringType, true) :: Nil) - ctx.createExternalTable( + createExternalTable( "createdJsonTable", "org.apache.spark.sql.json", schema, Map("path" -> tempPath.toString)) checkAnswer( - ctx.sql("SELECT * FROM createdJsonTable"), - ctx.sql("SELECT b FROM savedJsonTable")) + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) - ctx.sql("DROP TABLE createdJsonTable") + sql("DROP TABLE createdJsonTable") assert( intercept[RuntimeException] { - ctx.createExternalTable( + createExternalTable( "createdJsonTable", "org.apache.spark.sql.json", schema, @@ -564,16 +565,16 @@ class MetastoreDataSourcesSuite (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") withTable("test_parquet_ctas") { - ctx.sql( + sql( """CREATE TABLE test_parquet_ctas STORED AS PARQUET |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 """.stripMargin) checkAnswer( - ctx.sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), Row(3) :: Row(4) :: Nil) - ctx.table("test_parquet_ctas").queryExecution.optimizedPlan match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") @@ -629,10 +630,10 @@ class MetastoreDataSourcesSuite .mode(SaveMode.Append) .saveAsTable("arrayInParquet") - ctx.refreshTable("arrayInParquet") + refreshTable("arrayInParquet") checkAnswer( - ctx.sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM arrayInParquet"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -688,10 +689,10 @@ class MetastoreDataSourcesSuite .mode(SaveMode.Append) .saveAsTable("mapInParquet") - ctx.refreshTable("mapInParquet") + refreshTable("mapInParquet") checkAnswer( - ctx.sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM mapInParquet"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -706,7 +707,7 @@ class MetastoreDataSourcesSuite val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) // Manually create a metastore data source table. 
- ctx.catalog.createDataSourceTable( + catalog.createDataSourceTable( tableName = "wide_schema", userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -714,9 +715,9 @@ class MetastoreDataSourcesSuite options = Map("path" -> "just a dummy path"), isExternal = false) - ctx.invalidateTable("wide_schema") + invalidateTable("wide_schema") - val actualSchema = ctx.table("wide_schema").schema + val actualSchema = table("wide_schema").schema assert(schema === actualSchema) } } @@ -737,12 +738,12 @@ class MetastoreDataSourcesSuite "EXTERNAL" -> "FALSE"), tableType = ManagedTable, serdeProperties = Map( - "path" -> ctx.catalog.hiveDefaultTableFilePath(tableName))) + "path" -> catalog.hiveDefaultTableFilePath(tableName))) - ctx.catalog.client.createTable(hiveTable) + catalog.client.createTable(hiveTable) - ctx.invalidateTable(tableName) - val actualSchema = ctx.table(tableName).schema + invalidateTable(tableName) + val actualSchema = table(tableName).schema assert(schema === actualSchema) } } @@ -753,8 +754,8 @@ class MetastoreDataSourcesSuite withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - ctx.invalidateTable(tableName) - val metastoreTable = ctx.catalog.client.getTable("default", tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val actualPartitionColumns = StructType( @@ -768,7 +769,7 @@ class MetastoreDataSourcesSuite // Check the content of the saved table. checkAnswer( - ctx.table(tableName).select("c", "b", "d", "a"), + table(tableName).select("c", "b", "d", "a"), df.select("c", "b", "d", "a")) } } @@ -781,7 +782,7 @@ class MetastoreDataSourcesSuite withTable("insertParquet") { createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { @@ -790,12 +791,12 @@ class MetastoreDataSourcesSuite createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { @@ -804,26 +805,26 @@ class MetastoreDataSourcesSuite createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + 
sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM insertParquet p"), (50 to 59).map(i => Row(i, s"str$i"))) createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") checkAnswer( - ctx.sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM insertParquet p"), (70 to 79).map(i => Row(i, s"str$i"))) } } @@ -831,17 +832,17 @@ class MetastoreDataSourcesSuite test("SPARK-8156:create table to specific database by 'use dbname' ") { val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - ctx.sql("""create database if not exists testdb8156""") - ctx.sql("""use testdb8156""") + sqlContext.sql("""create database if not exists testdb8156""") + sqlContext.sql("""use testdb8156""") df.write .format("parquet") .mode(SaveMode.Overwrite) .saveAsTable("ttt3") checkAnswer( - ctx.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), Row("ttt3", false)) - ctx.sql("""use default""") - ctx.sql("""drop database if exists testdb8156 CASCADE""") + sqlContext.sql("""use default""") + sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 1ff256d3b209..73852f13ad20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,30 +17,35 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.SharedHiveContext -import org.apache.spark.sql.{QueryTest, SaveMode} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} -class MultiDatabaseSuite extends QueryTest with SharedHiveContext { - private lazy val df = ctx.range(10).coalesce(1) +class MultiDatabaseSuite extends QueryTest with SQLTestUtils { + override val sqlContext: SQLContext = TestHive + + import sqlContext.sql + + private val df = sqlContext.range(10).coalesce(1) test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(ctx.tableNames().contains("t")) - checkAnswer(ctx.table("t"), df) + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) } - assert(ctx.tableNames(db).contains("t")) - checkAnswer(ctx.table(s"$db.t"), df) + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) } } test(s"saveAsTable() to non-default database - without USE - Overwrite") { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") - assert(ctx.tableNames(db).contains("t")) - checkAnswer(ctx.table(s"$db.t"), df) + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) } } @@ -49,12 +54,12 @@ class MultiDatabaseSuite extends QueryTest with SharedHiveContext { activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") - assert(ctx.tableNames().contains("t")) - checkAnswer(ctx.table("t"), 
df.unionAll(df)) + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) } - assert(ctx.tableNames(db).contains("t")) - checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) } } @@ -62,8 +67,8 @@ class MultiDatabaseSuite extends QueryTest with SharedHiveContext { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") - assert(ctx.tableNames(db).contains("t")) - checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) } } @@ -71,10 +76,10 @@ class MultiDatabaseSuite extends QueryTest with SharedHiveContext { withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(ctx.tableNames().contains("t")) + assert(sqlContext.tableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) } } } @@ -83,46 +88,46 @@ class MultiDatabaseSuite extends QueryTest with SharedHiveContext { withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(ctx.tableNames().contains("t")) + assert(sqlContext.tableNames().contains("t")) } - assert(ctx.tableNames(db).contains("t")) + assert(sqlContext.tableNames(db).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(ctx.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) } } test("Looks up tables in non-default database") { withTempDatabase { db => activateDatabase(db) { - ctx.sql("CREATE TABLE t (key INT)") - checkAnswer(ctx.table("t"), ctx.emptyDataFrame) + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) } - checkAnswer(ctx.table(s"$db.t"), ctx.emptyDataFrame) + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) } } test("Drops a table in a non-default database") { withTempDatabase { db => activateDatabase(db) { - ctx.sql(s"CREATE TABLE t (key INT)") - assert(ctx.tableNames().contains("t")) - assert(!ctx.tableNames("default").contains("t")) + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) } - assert(!ctx.tableNames().contains("t")) - assert(ctx.tableNames(db).contains("t")) + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) activateDatabase(db) { - ctx.sql(s"DROP TABLE t") - assert(!ctx.tableNames().contains("t")) - assert(!ctx.tableNames("default").contains("t")) + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) } - assert(!ctx.tableNames().contains("t")) - assert(!ctx.tableNames(db).contains("t")) + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) } } @@ -134,19 +139,19 @@ class MultiDatabaseSuite extends QueryTest with SharedHiveContext { val path = dir.getCanonicalPath activateDatabase(db) { - ctx.sql( + sql( s"""CREATE EXTERNAL TABLE t (id BIGINT) |PARTITIONED BY (p INT) |STORED AS PARQUET |LOCATION '$path' """.stripMargin) - checkAnswer(ctx.table("t"), ctx.emptyDataFrame) + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) 
df.write.parquet(s"$path/p=1") - ctx.sql("ALTER TABLE t ADD PARTITION (p=1)") - ctx.sql("REFRESH TABLE t") - checkAnswer(ctx.table("t"), df.withColumn("p", lit(1))) + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 94b866340d33..251e0324bfa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.{Row, SQLConf, SQLContext} -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with SharedHiveContext { +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable + override val sqlContext: SQLContext = TestHive + /** * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. @@ -36,31 +38,33 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with Shared withTable("parquet_compat") { withTempPath { dir => val path = dir.getCanonicalPath + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { withTempTable("data") { - ctx.sql( + sqlContext.sql( s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '$path' - """.stripMargin) + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) - val schema = ctx.table("parquet_compat").schema - val rowRDD = ctx.sparkContext.parallelize(makeRows).coalesce(1) - ctx.createDataFrame(rowRDD, schema).registerTempTable("data") - ctx.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } + val schema = readParquetSchema(path, { path => !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) }) @@ -73,13 +77,13 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with Shared // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. 
withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(ctx.read.parquet(path), makeRows) + checkAnswer(sqlContext.read.parquet(path), makeRows) } } } } - private def makeRows: Seq[Row] = { + def makeRows: Seq[Row] = { (0 until 10).map { i => def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 418205aee0dd..017bc2adc103 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.hive.test.SharedHiveContext -import org.apache.spark.sql.test.SQLTestData.TestData -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class QueryPartitionSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ + import ctx.sql test("SPARK-5068: query data when path doesn't exist"){ val testData = ctx.sparkContext.parallelize( @@ -35,19 +36,19 @@ class QueryPartitionSuite extends QueryTest with SharedHiveContext { val tmpDir = Files.createTempDir() // create the table for test - ctx.sql(s"CREATE TABLE table_with_partition(key int,value string) " + + sql(s"CREATE TABLE table_with_partition(key int,value string) " + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + "SELECT key,value FROM testData") - ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + "SELECT key,value FROM testData") - ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + "SELECT key,value FROM testData") - ctx.sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + "SELECT key,value FROM testData") // test for the exist path - checkAnswer(ctx.sql("select key,value from table_with_partition"), + checkAnswer(sql("select key,value from table_with_partition"), testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) @@ -57,10 +58,10 @@ class QueryPartitionSuite extends QueryTest with SharedHiveContext { .foreach { f => Utils.deleteRecursively(f) } // test for after delete the path - checkAnswer(ctx.sql("select key,value from table_with_partition"), + checkAnswer(sql("select key,value from table_with_partition"), testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - ctx.sql("DROP TABLE table_with_partition") - ctx.sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 7ce385f1a613..93dcb10f7a29 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.SharedHiveContext -class SerializationSuite extends SparkFunSuite with SharedHiveContext { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { - ctx.hiveconf + val hiveContext = org.apache.spark.sql.hive.test.TestHive + hiveContext.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() - val bytes = serializer.serialize(ctx) - serializer.deserialize[AnyRef](bytes) + val bytes = serializer.serialize(hiveContext) + val deSer = serializer.deserialize[AnyRef](bytes) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 7dcedaa02153..e4fec7e2c8a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,21 +17,25 @@ package org.apache.spark.sql.hive +import org.scalatest.BeforeAndAfterAll + import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.test.SharedHiveContext -class StatisticsSuite extends QueryTest with SharedHiveContext { +class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - protected override def beforeAll(): Unit = { - super.beforeAll() + private lazy val ctx: HiveContext = { + val ctx = org.apache.spark.sql.hive.test.TestHive ctx.reset() ctx.cacheTables = false + ctx } + import ctx.sql + test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { val parsed = HiveQl.parseSql(analyzeCommand) @@ -79,32 +83,32 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table - ctx.sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() - ctx.sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - ctx.sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - ctx.sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") + sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) - ctx.sql("DROP TABLE analyzeTable").collect() + sql("DROP TABLE analyzeTable").collect() // Partitioned table - ctx.sql( + sql( """ |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) """.stripMargin).collect() - ctx.sql( + sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') |SELECT * FROM src """.stripMargin).collect() - ctx.sql( + sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') |SELECT * FROM src """.stripMargin).collect() - ctx.sql( + sql( """ |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') |SELECT * FROM src @@ -112,14 +116,14 @@ class StatisticsSuite extends QueryTest with 
SharedHiveContext { assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) - ctx.sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") + sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) - ctx.sql("DROP TABLE analyzeTable_part").collect() + sql("DROP TABLE analyzeTable_part").collect() // Try to analyze a temp table - ctx.sql("""SELECT * FROM src""").registerTempTable("tempTable") + sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { ctx.analyze("tempTable") } @@ -127,7 +131,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { } test("estimates the size of a test MetastoreRelation") { - val df = ctx.sql("""SELECT * FROM src""") + val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } @@ -145,7 +149,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ct: ClassTag[_]): Unit = { before() - var df = ctx.sql(query) + var df = sql(query) // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { @@ -166,8 +170,8 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") - df = ctx.sql(query) + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") + df = sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -175,7 +179,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - ctx.sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -199,7 +203,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin val answer = Row(86, "val_86") - var df = ctx.sql(leftSemiJoinQuery) + var df = sql(leftSemiJoinQuery) // Assert src has a size smaller than the threshold. 
val sizes = df.queryExecution.analyzed.collect { @@ -224,8 +228,8 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - df = ctx.sql(leftSemiJoinQuery) + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") + df = sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } @@ -237,7 +241,7 @@ class StatisticsSuite extends QueryTest with SharedHiveContext { assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 280587076755..9b3ede43ee2d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.{Row, QueryTest} case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SharedHiveContext { +class UDFSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 8d1e9c92e992..7b5aa4763fd9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,21 +17,24 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql._ +import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { - import testImplicits._ +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + + override val sqlContext = TestHive + import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ override def beforeAll(): Unit = { - super.beforeAll() - originalUseAggregate2 = ctx.conf.useSqlAggregate2 - ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -64,31 +67,27 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") - val emptyDF = ctx.createDataFrame( - ctx.sparkContext.emptyRDD[Row], + val emptyDF = sqlContext.createDataFrame( + 
sqlContext.sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") // Register UDAFs - ctx.udf.register("mydoublesum", new MyDoubleSum) - ctx.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) } override def afterAll(): Unit = { - try { - ctx.sql("DROP TABLE IF EXISTS agg1") - ctx.sql("DROP TABLE IF EXISTS agg2") - ctx.dropTempTable("emptyTable") - ctx.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) - } finally { - super.afterAll() - } + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | AVG(value), @@ -105,7 +104,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | AVG(value), @@ -124,7 +123,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { // If there is a GROUP BY clause and the table is empty, there is no output. checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | AVG(value), @@ -144,7 +143,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("null literal") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | AVG(null), @@ -160,7 +159,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("only do grouping") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT key |FROM agg1 @@ -169,7 +168,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT DISTINCT value1, key |FROM agg2 @@ -186,7 +185,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(null, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT value1, key |FROM agg2 @@ -206,7 +205,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("case in-sensitive resolution") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value), kEY - 100 |FROM agg1 @@ -215,7 +214,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT sum(distinct value1), kEY - 100, count(distinct value1) |FROM agg2 @@ -224,7 +223,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT valUe * key - 100 |FROM agg1 @@ -240,7 +239,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("test average no key in output") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value) |FROM agg1 @@ -251,7 +250,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("test average") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT key, 
avg(value) |FROM agg1 @@ -260,7 +259,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value), key |FROM agg1 @@ -269,7 +268,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value) + 1.5, key + 10 |FROM agg1 @@ -278,7 +277,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT avg(value) FROM agg1 """.stripMargin), @@ -287,7 +286,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("udaf") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | key, @@ -307,7 +306,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("non-AlgebraicAggregate aggreguate function") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(value), key |FROM agg1 @@ -316,14 +315,14 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(value) FROM agg1 """.stripMargin), Row(89.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(null) """.stripMargin), @@ -332,7 +331,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT mydoublesum(value), key, avg(value) |FROM agg1 @@ -344,7 +343,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(30.0, null, 10.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | mydoublesum(value + 1.5 * key), @@ -364,7 +363,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("single distinct column set") { // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | min(distinct value1), @@ -377,7 +376,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | mydoubleavg(distinct value1), @@ -396,7 +395,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | key, @@ -414,7 +413,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | count(value1), @@ -433,7 +432,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("test count") { checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | count(value2), @@ -456,7 +455,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { Row(0, null, 1, 1, null) :: Nil) checkAnswer( - ctx.sql( + sqlContext.sql( """ |SELECT | count(value2), @@ -483,7 +482,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { test("error handling") { withSQLConf("spark.sql.useAggregate2" -> "false") { val errorMessage = intercept[AnalysisException] { - ctx.sql( + sqlContext.sql( """ |SELECT | key, @@ -501,7 +500,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { // we can remove the following two tests. withSQLConf("spark.sql.useAggregate2" -> "true") { val errorMessage = intercept[AnalysisException] { - ctx.sql( + sqlContext.sql( """ |SELECT | key, @@ -514,7 +513,7 @@ abstract class AggregationQuerySuite extends QueryTest with SharedHiveContext { assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) // This will fall back to the old aggregate - val newAggregateOperators = ctx.sql( + val newAggregateOperators = sqlContext.sql( """ |SELECT | key, @@ -539,17 +538,14 @@ class SortBasedAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") super.beforeAll() - originalUnsafeEnabled = ctx.conf.unsafeEnabled - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "false") } override def afterAll(): Unit = { - try { - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } finally { - super.afterAll() - } + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } @@ -558,17 +554,14 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") super.beforeAll() - originalUnsafeEnabled = ctx.conf.unsafeEnabled - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") } override def afterAll(): Unit = { - try { - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } finally { - super.afterAll() - } + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } @@ -577,29 +570,26 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue var originalUnsafeEnabled: Boolean = _ override def beforeAll(): Unit = { + originalUnsafeEnabled = 
sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") super.beforeAll() - originalUnsafeEnabled = ctx.conf.unsafeEnabled - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, "true") } override def afterAll(): Unit = { - try { - ctx.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - ctx.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") - } finally { - super.afterAll() - } + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") } override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - ctx.setConf( + sqlContext.setConf( "spark.sql.TungstenAggregate.testFallbackStartsAt", fallbackStartsAt.toString) // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. - val newActual = DataFrame(ctx, actual.logicalPlan) + val newActual = DataFrame(sqlContext, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 5182481297a8..a3f5921a0cb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -19,26 +19,20 @@ package org.apache.spark.sql.hive.execution import java.io.File -import org.apache.spark.sql.hive.test.TestHiveContext.TestTable +import org.apache.spark.sql.hive.test.TestHive._ /** * A set of test cases based on the big-data-benchmark. 
* https://amplab.cs.berkeley.edu/benchmark/ */ class BigDataBenchmarkSuite extends HiveComparisonTest { + val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") - private val testDataDirectory = - new File("target" + File.separator + "big-data-benchmark-testdata") - private val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath - - protected override def beforeAll(): Unit = { - super.beforeAll() - val _ctx = ctx - import _ctx._ - val testTables = Seq( - TestTable( - "rankings", - s""" + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath + val testTables = Seq( + TestTable( + "rankings", + s""" |CREATE EXTERNAL TABLE rankings ( | pageURL STRING, | pageRank INT, @@ -46,9 +40,9 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "rankings").getCanonicalPath}" """.stripMargin.cmd), - TestTable( - "scratch", - s""" + TestTable( + "scratch", + s""" |CREATE EXTERNAL TABLE scratch ( | pageURL STRING, | pageRank INT, @@ -56,9 +50,9 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "scratch").getCanonicalPath}" """.stripMargin.cmd), - TestTable( - "uservisits", - s""" + TestTable( + "uservisits", + s""" |CREATE EXTERNAL TABLE uservisits ( | sourceIP STRING, | destURL STRING, @@ -72,15 +66,15 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), - TestTable( - "documents", - s""" + TestTable( + "documents", + s""" |CREATE EXTERNAL TABLE documents (line STRING) |STORED AS TEXTFILE |LOCATION "${new File(testDataDirectory, "crawl").getCanonicalPath}" """.stripMargin.cmd)) - testTables.foreach(registerTestTable) - } + + testTables.foreach(registerTestTable) if (!testDataDirectory.exists()) { // TODO: Auto download the files on demand. 
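The hunks in this part of the patch all apply the same conversion: each Hive test suite stops mixing in a shared-context trait and instead points at the global TestHive singleton, either by overriding sqlContext and mixing in SQLTestUtils or by importing sql and the implicits from TestHive directly. The following is a minimal, hypothetical sketch of that idiom; the suite name ExampleHiveSuite, the temp table nums, and the query are illustrative only and do not appear in the patch.

// Illustrative sketch only: mirrors the pattern used in the hunks above
// (e.g. MultiDatabaseSuite, AggregationQuerySuite); not part of the patch.
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.{QueryTest, Row, SQLContext}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils

class ExampleHiveSuite extends QueryTest with SQLTestUtils {

  // Point the shared test utilities at the TestHive singleton instead of
  // mixing in a per-suite context trait.
  override val sqlContext: SQLContext = TestHive

  // Bring sql(...) and the implicit conversions (toDF, etc.) into scope so
  // test bodies can call sql("...") without a ctx. prefix.
  import sqlContext.implicits._
  import sqlContext.sql

  test("queries run against the shared TestHive context") {
    withTempTable("nums") {
      // Register a small temp table and query it through the shared context.
      (1 to 3).map(i => (i, s"val_$i")).toDF("key", "value").registerTempTable("nums")
      checkAnswer(
        sql("SELECT key FROM nums WHERE key > 1"),
        Row(2) :: Row(3) :: Nil)
    }
  }
}

Because the context is a process-wide singleton in this scheme, suites that mutate configuration (as the aggregation suites below do with SQLConf.UNSAFE_ENABLED) have to restore the original values in afterAll, otherwise the setting leaks into later suites.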
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 6bae033be915..b0d3dd44daed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -19,23 +19,17 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends SparkFunSuite { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => - var sc: SparkContext = null - try { - sc = new SparkContext("local", s"TestSQLContext$i", new SparkConf()) - val ts = new TestHiveContext(sc) - ts.executeSql("SHOW TABLES").toRdd.collect() - ts.executeSql("SELECT * FROM src").toRdd.collect() - ts.executeSql("SHOW TABLES").toRdd.collect() - } finally { - if (sc != null) { - sc.stop() - } - } + val ts = + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + ts.executeSql("SHOW TABLES").toRdd.collect() + ts.executeSql("SELECT * FROM src").toRdd.collect() + ts.executeSql("SHOW TABLES").toRdd.collect() } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 6d391c57b8c5..2bdb0e11878e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.GivenWhenThen +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive /** * Allows the creations of tests that execute the same query against both hive @@ -41,10 +40,7 @@ import org.apache.spark.sql.hive.test.SharedHiveContext * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite - with GivenWhenThen - with SharedHiveContext - with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. 
Used when the test @@ -132,9 +128,9 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } - private def prepareAnswer(ctx: SQLContext)( - hiveQuery: ctx.type#QueryExecution, - answer: Seq[String]): Seq[String] = { + protected def prepareAnswer( + hiveQuery: TestHive.type#QueryExecution, + answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false @@ -271,11 +267,9 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - val _ctx = ctx - try { if (reset) { - ctx.reset() + TestHive.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { @@ -304,7 +298,7 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new _ctx.QueryExecution(_)) + val hiveQueries = queryList.map(new TestHive.QueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. // Note this must only look at the logical plan as we might not be able to analyze if // other DDL has not been executed yet. @@ -324,7 +318,7 @@ abstract class HiveComparisonTest case _: ExplainCommand => // No need to execute EXPLAIN queries as we don't check the output. Nil - case _ => ctx.runSqlHive(queryString) + case _ => TestHive.runSqlHive(queryString) } // We need to add a new line to non-empty answers so we can differentiate Seq() @@ -347,15 +341,15 @@ abstract class HiveComparisonTest fail(errorMessage) } }.toSeq - if (reset) { ctx.reset() } + if (reset) { TestHive.reset() } computedResults } // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new _ctx.QueryExecution(queryString) - try { (query, prepareAnswer(_ctx)(query, query.stringResult())) } catch { + val query = new TestHive.QueryExecution(queryString) + try { (query, prepareAnswer(query, query.stringResult())) } catch { case e: Throwable => val errorMessage = s""" @@ -374,7 +368,7 @@ abstract class HiveComparisonTest (queryList, hiveResults, catalystResults).zipped.foreach { case (query, hive, (hiveQuery, catalyst)) => // Check that the results match unless its an EXPLAIN query. - val preparedHive = prepareAnswer(_ctx)(hiveQuery, hive) + val preparedHive = prepareAnswer(hiveQuery, hive) // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && @@ -414,8 +408,8 @@ abstract class HiveComparisonTest // okay by running a simple query. If this fails then we halt testing since // something must have gone seriously wrong. 
try { - new _ctx.QueryExecution("SELECT key FROM src").stringResult() - ctx.runSqlHive("SELECT key FROM src") + new TestHive.QueryExecution("SELECT key FROM src").stringResult() + TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => logError(s"FATAL ERROR: Canary query threw $e This implies that the " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index cfc997395d3b..44c5b80392fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,22 +17,26 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest with SharedHiveContext { +class HiveExplainSuite extends QueryTest with SQLTestUtils { + + def sqlContext: SQLContext = TestHive test("explain extended command") { - checkExistence(ctx.sql(" explain select * from src where key=123 "), true, + checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") - checkExistence(ctx.sql(" explain select * from src where key=123 "), false, + checkExistence(sql(" explain select * from src where key=123 "), false, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==") - checkExistence(ctx.sql(" explain extended select * from src where key=123 "), true, + checkExistence(sql(" explain extended select * from src where key=123 "), true, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", @@ -41,15 +45,13 @@ class HiveExplainSuite extends QueryTest with SharedHiveContext { } test("explain create table command") { - checkExistence(ctx.sql("explain create table temp__b as select * from src limit 2"), true, + checkExistence(sql("explain create table temp__b as select * from src limit 2"), true, "== Physical Plan ==", "InsertIntoHiveTable", "Limit", "src") - checkExistence(ctx.sql( - "explain extended create table temp__b as select * from src limit 2"), - true, + checkExistence(sql("explain extended create table temp__b as select * from src limit 2"), true, "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", @@ -59,7 +61,7 @@ class HiveExplainSuite extends QueryTest with SharedHiveContext { "Limit", "src") - checkExistence(ctx.sql( + checkExistence(sql( """ | EXPLAIN EXTENDED CREATE TABLE temp__b | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -80,9 +82,9 @@ class HiveExplainSuite extends QueryTest with SharedHiveContext { test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempTable("jt") { - val rdd = ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - ctx.read.json(rdd).registerTempTable("jt") - val outputs = ctx.sql( + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd).registerTempTable("jt") + val outputs = sql( s""" |EXPLAIN EXTENDED |CREATE TABLE t1 diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index 63876dc5f0cb..efbef68cd444 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,33 +18,32 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive._ /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest with SharedHiveContext { - +class HiveOperatorQueryableSuite extends QueryTest { test("SPARK-5324 query result of describe command") { - ctx.loadTestTable("src") + loadTestTable("src") // register a describe command to be a temp table - ctx.sql("desc src").registerTempTable("mydesc") + sql("desc src").registerTempTable("mydesc") checkAnswer( - ctx.sql("desc mydesc"), + sql("desc mydesc"), Seq( Row("col_name", "string", "name of the column"), Row("data_type", "string", "data type of the column"), Row("comment", "string", "comment of the column"))) checkAnswer( - ctx.sql("select * from mydesc"), + sql("select * from mydesc"), Seq( Row("key", "int", null), Row("value", "string", null))) checkAnswer( - ctx.sql("select col_name, data_type, comment from mydesc"), + sql("select col_name, data_type, comment from mydesc"), Seq( Row("key", "int", null), Row("value", "string", null))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index cd1bfa43ecca..ba56a8a6b689 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,15 +21,16 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive -class HivePlanTest extends QueryTest with SharedHiveContext { - import testImplicits._ +class HivePlanTest extends QueryTest { + import TestHive._ + import TestHive.implicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") - val optimized = ctx.sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan - val correctAnswer = ctx.sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0ccc95227ab6..83f9f3eaa3a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,50 +22,49 @@ import java.util.{Locale, TimeZone} import scala.util.Try +import org.scalatest.BeforeAndAfter + import 
org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.test.SQLTestData.TestData +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest { - import testImplicits._ - +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - override def beforeAll(): Unit = { - super.beforeAll() - ctx.cacheTables = true + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + override def beforeAll() { + TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) } - override def afterAll(): Unit = { - try { - ctx.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - ctx.sql("DROP TEMPORARY FUNCTION udtf_count2") - } finally { - super.afterAll() - } + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION udtf_count2") } test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => - ctx.sql("USE default") - ctx.sql("SHOW DATABASES") + sql("USE default") + sql("SHOW DATABASES") } } @@ -149,11 +148,11 @@ class HiveQuerySuite extends HiveComparisonTest { test("multiple generators in projection") { intercept[AnalysisException] { - ctx.sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() + sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() } intercept[AnalysisException] { - ctx.sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() + sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() } } @@ -243,8 +242,8 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) test("CREATE TABLE AS runs once") { - ctx.sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(ctx.sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, "Incorrect number of rows in created table") } @@ -256,7 +255,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Jdk version leads to different query output for double, so not use createQueryTest here test("division") { - val res = ctx.sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head + val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x => assert(x._1 == x._2.asInstanceOf[Double])) } @@ -266,17 +265,17 @@ class HiveQuerySuite extends HiveComparisonTest { "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed 
in SQL") { - ctx.setConf("spark.sql.dialect", "sql") - assert(ctx.sql("SELECT 1").collect() === Array(Row(1))) - ctx.setConf("spark.sql.dialect", "hiveql") + setConf("spark.sql.dialect", "sql") + assert(sql("SELECT 1").collect() === Array(Row(1))) + setConf("spark.sql.dialect", "hiveql") } test("Query expressed in HiveQL") { - ctx.sql("FROM src SELECT key").collect() + sql("FROM src SELECT key").collect() } test("Query with constant folding the CAST") { - ctx.sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect() + sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect() } createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", @@ -375,10 +374,10 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) test("SPARK-7270: consider dynamic partition when comparing table output") { - ctx.sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") - ctx.sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") - val analyzedPlan = ctx.sql( + val analyzedPlan = sql( """ |INSERT OVERWRITE table test_partition PARTITION (b=1, c) |SELECT 'a', 'c' from ptest @@ -432,11 +431,11 @@ class HiveQuerySuite extends HiveComparisonTest { test("transform with SerDe2") { - ctx.sql("CREATE TABLE small_src(key INT, value STRING)") - ctx.sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") + sql("CREATE TABLE small_src(key INT, value STRING)") + sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") - val expected = ctx.sql("SELECT key FROM small_src").collect().head - val res = ctx.sql( + val expected = sql("SELECT key FROM small_src").collect().head + val res = sql( """ |SELECT TRANSFORM (key) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.avro.AvroSerDe' @@ -510,13 +509,13 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") test("sampling") { - ctx.sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") - ctx.sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") } test("DataFrame toString") { - ctx.sql("SHOW TABLES").toString - ctx.sql("SELECT * FROM src").toString + sql("SHOW TABLES").toString + sql("SELECT * FROM src").toString } createQueryTest("case statements with key #1", @@ -545,7 +544,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { - val res = ctx.sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head assert(0.001 == res.getDouble(0)) } @@ -640,7 +639,7 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT get_json_object(src_json.json, '$.fb:testid') FROM src_json") test("predicates contains an empty AttributeSet() references") { - ctx.sql( + sql( """ |SELECT a FROM ( | SELECT 1 AS a FROM src LIMIT 1 ) t @@ -649,12 +648,12 @@ class HiveQuerySuite extends HiveComparisonTest { } test("implement identity function using case statement") { - val actual = ctx.sql("SELECT (CASE key WHEN key THEN key END) FROM src") + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => 
i } .collect() .toSet - val expected = ctx.sql("SELECT key FROM src") + val expected = sql("SELECT key FROM src") .map { case Row(i: Int) => i } .collect() .toSet @@ -666,7 +665,7 @@ class HiveQuerySuite extends HiveComparisonTest { // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. ignore("non-boolean conditions in a CaseWhen are illegal") { intercept[Exception] { - ctx.sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() } } @@ -675,13 +674,13 @@ class HiveQuerySuite extends HiveComparisonTest { test("case sensitivity: registered table") { val testData = - ctx.sparkContext.parallelize( + TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) testData.toDF().registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { - ctx.sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } } @@ -692,94 +691,92 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-1704: Explain commands as a DataFrame") { - ctx.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val df = ctx.sql("explain select key, count(value) from src group by key") + val df = sql("explain select key, count(value) from src group by key") assert(isExplanation(df)) - ctx.reset() + TestHive.reset() } test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - ctx.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = - ctx.sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) assert(results === Array(Pair("foo", 4))) - ctx.reset() + TestHive.reset() } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { - ctx.sql("select key, count(*) c from src group by key having c").collect() + sql("select key, count(*) c from src group by key having c").collect() } test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(ctx.sql("select key from src having key > 490").collect().size < 100) + assert(sql("select key from src having key > 490").collect().size < 100) } test("SPARK-5383 alias for udfs with multi output columns") { assert( - ctx.sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") .collect() .size == 5) assert( - ctx.sql( - "select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") .collect() .size == 5) } test("SPARK-5367: resolve star expression in udf") { - assert(ctx.sql("select concat(*) from src limit 5").collect().size == 5) - assert(ctx.sql("select array(*) from src limit 5").collect().size == 5) - assert(ctx.sql("select concat(key, *) from src limit 5").collect().size == 5) - 
assert(ctx.sql("select array(key, *) from src limit 5").collect().size == 5) + assert(sql("select concat(*) from src limit 5").collect().size == 5) + assert(sql("select array(*) from src limit 5").collect().size == 5) + assert(sql("select concat(key, *) from src limit 5").collect().size == 5) + assert(sql("select array(key, *) from src limit 5").collect().size == 5) } test("Query Hive native command execution result") { val databaseName = "test_native_commands" assertResult(0) { - ctx.sql(s"DROP DATABASE IF EXISTS $databaseName").count() + sql(s"DROP DATABASE IF EXISTS $databaseName").count() } assertResult(0) { - ctx.sql(s"CREATE DATABASE $databaseName").count() + sql(s"CREATE DATABASE $databaseName").count() } assert( - ctx.sql("SHOW DATABASES") + sql("SHOW DATABASES") .select('result) .collect() .map(_.getString(0)) .contains(databaseName)) - assert(isExplanation(ctx.sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) - ctx.reset() + TestHive.reset() } test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = ctx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail - assert(Try(ctx.table(tableName)).isSuccess) + assert(Try(table(tableName)).isSuccess) // If the CREATE TABLE command got executed again, the following assertion would fail assert(Try(q0.count()).isSuccess) } test("DESCRIBE commands") { - ctx.sql( - s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - ctx.sql( + sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') |SELECT key, value """.stripMargin) @@ -794,7 +791,7 @@ class HiveQuerySuite extends HiveComparisonTest { Row("# col_name", "data_type", "comment"), Row("dt", "string", null)) ) { - ctx.sql("DESCRIBE test_describe_commands1") + sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } @@ -809,14 +806,14 @@ class HiveQuerySuite extends HiveComparisonTest { Row("# col_name", "data_type", "comment"), Row("dt", "string", null)) ) { - ctx.sql("DESCRIBE default.test_describe_commands1") + sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - ctx.sql("DESCRIBE test_describe_commands1 value") + sql("DESCRIBE test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -824,7 +821,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - ctx.sql("DESCRIBE default.test_describe_commands1 value") + sql("DESCRIBE default.test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -842,7 +839,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array(""), Array("dt", "string")) ) { - ctx.sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") .select('result) .collect() .map(_.getString(0).replaceAll("None", "").trim.split("\t").map(_.trim)) @@ -850,7 
+847,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Describe a registered temporary table. val testData = - ctx.sparkContext.parallelize( + TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) testData.toDF().registerTempTable("test_describe_commands2") @@ -860,16 +857,16 @@ class HiveQuerySuite extends HiveComparisonTest { Row("a", "int", ""), Row("b", "string", "")) ) { - ctx.sql("DESCRIBE test_describe_commands2") + sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) .collect() } } test("SPARK-2263: Insert Map values") { - ctx.sql("CREATE TABLE m(value MAP)") - ctx.sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - ctx.sql("SELECT * FROM m").collect().zip(ctx.sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("CREATE TABLE m(value MAP)") + sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -877,35 +874,35 @@ class HiveQuerySuite extends HiveComparisonTest { } test("ADD JAR command") { - val testJar = TestHiveContext.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath - ctx.sql("CREATE TABLE alter1(a INT, b INT)") + val testJar = TestHive.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + sql("CREATE TABLE alter1(a INT, b INT)") intercept[Exception] { - ctx.sql( + sql( """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' |WITH serdeproperties('s1'='9') """.stripMargin) } - ctx.sql("DROP TABLE alter1") + sql("DROP TABLE alter1") } test("ADD JAR command 2") { // this is a test case from mapjoin_addjar.q - val testJar = TestHiveContext.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath - val testData = TestHiveContext.getHiveFile("data/files/sample.json").getCanonicalPath - ctx.sql(s"ADD JAR $testJar") - ctx.sql( + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + sql(s"ADD JAR $testJar") + sql( """CREATE TABLE t1(a string, b string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - ctx.sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - ctx.sql("select * from src join t1 on src.key = t1.a") - ctx.sql("DROP TABLE t1") + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") } test("ADD FILE command") { - val testFile = TestHiveContext.getHiveFile("data/files/v1.txt").getCanonicalFile - ctx.sql(s"ADD FILE $testFile") + val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile + sql(s"ADD FILE $testFile") - val checkAddFileRDD = ctx.sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => + val checkAddFileRDD = sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => Iterator.single(new File(SparkFiles.get("v1.txt")).canRead) } @@ -938,10 +935,9 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) ignore("Dynamic partition folder layout") { - ctx.sql("DROP TABLE IF EXISTS dynamic_part_table") - ctx.sql( - "CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") - ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) 
PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") val data = Map( Seq("1", "1") -> 1, @@ -950,7 +946,7 @@ class HiveQuerySuite extends HiveComparisonTest { Seq("NULL", "NULL") -> 4) data.foreach { case (parts, value) => - ctx.sql( + sql( s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 """.stripMargin) @@ -967,18 +963,18 @@ class HiveQuerySuite extends HiveComparisonTest { .mkString("/") // Loads partition data to a temporary table to verify contents - val path = s"${ctx.warehousePath}/dynamic_part_table/$partFolder/part-00000" + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" - ctx.sql("DROP TABLE IF EXISTS dp_verify") - ctx.sql("CREATE TABLE dp_verify(intcol INT)") - ctx.sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - assert(ctx.sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) } } test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") { - ctx.sql(""" + sql(""" |create table sc as select * |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows) |union all @@ -986,31 +982,31 @@ class HiveQuerySuite extends HiveComparisonTest { |union all |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s """.stripMargin) - ctx.sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") - ctx.sql("set hive.exec.dynamic.partition=true") - ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") - ctx.sql("insert overwrite table sc_part partition(ts) select * from sc") - ctx.sql("drop table sc_part") + sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") + sql("set hive.exec.dynamic.partition=true") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("insert overwrite table sc_part partition(ts) select * from sc") + sql("drop table sc_part") } test("Partition spec validation") { - ctx.sql("DROP TABLE IF EXISTS dp_test") - ctx.sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") - ctx.sql("SET hive.exec.dynamic.partition.mode=strict") + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") // Should throw when using strict dynamic partition mode without any static partition intercept[SparkException] { - ctx.sql( + sql( """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src """.stripMargin) } - ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") // Should throw when a static partition appears after a dynamic partition intercept[SparkException] { - ctx.sql( + sql( """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) |SELECT key, value, key % 5 FROM src """.stripMargin) @@ -1018,10 +1014,10 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - ctx.sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") - ctx.sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") + 
sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") - ctx.sql( + sql( """ SELECT name, message FROM rawLogs @@ -1033,15 +1029,15 @@ class HiveQuerySuite extends HiveComparisonTest { """).registerTempTable("boom") // This should be successfully analyzed - ctx.sql("SELECT * FROM boom").queryExecution.analyzed + sql("SELECT * FROM boom").queryExecution.analyzed } test("SPARK-3810: PreInsertionCasts static partitioning support") { val analyzedPlan = { - ctx.loadTestTable("srcpart") - ctx.sql("DROP TABLE IF EXISTS withparts") - ctx.sql("CREATE TABLE withparts LIKE srcpart") - ctx.sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") .queryExecution.analyzed } @@ -1054,13 +1050,13 @@ class HiveQuerySuite extends HiveComparisonTest { test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { val analyzedPlan = { - ctx.loadTestTable("srcpart") - ctx.sql("DROP TABLE IF EXISTS withparts") - ctx.sql("CREATE TABLE withparts LIKE srcpart") - ctx.sql("SET hive.exec.dynamic.partition.mode=nonstrict") + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") - ctx.sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - ctx.sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") .queryExecution.analyzed } @@ -1076,19 +1072,19 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") == testVal) + sql(s"set $testKey=$testVal") + assert(getConf(testKey, testVal + "_") == testVal) - ctx.sql("set some.property=20") - assert(ctx.getConf("some.property", "0") == "20") - ctx.sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") == "40") + sql("set some.property=20") + assert(getConf("some.property", "0") == "20") + sql("set some.property = 40") + assert(getConf("some.property", "0") == "40") - ctx.sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, "0") == testVal) + sql(s"set $testKey=$testVal") + assert(getConf(testKey, "0") == testVal) - ctx.sql(s"set $testKey=") - assert(ctx.getConf(testKey, "0") == "") + sql(s"set $testKey=") + assert(getConf(testKey, "0") == "") } test("SET commands semantics for a HiveContext") { @@ -1101,38 +1097,38 @@ class HiveQuerySuite extends HiveComparisonTest { case Row(key: String, value: String) => key -> value case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet - ctx.conf.clear() + conf.clear() - val expectedConfs = ctx.conf.getAllDefinedConfs.toSet - assertResult(expectedConfs)(collectResults(ctx.sql("SET -v"))) + val expectedConfs = conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(sql("SET -v"))) // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
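// A minimal sketch of the import style this hunk switches to, assuming the Spark 1.5-era
// TestHive singleton; the suite name and config key below are illustrative only.
// Importing members of the TestHive object lets tests call sql(...) and getConf(...)
// directly instead of routing every call through a ctx field.
import org.apache.spark.sql.hive.test.TestHive.{getConf, sql}

class ImportStyleExampleSuite extends org.scalatest.FunSuite {
  test("bare sql() and getConf() resolve against the TestHive singleton") {
    sql("SET spark.sql.example.key=42")                   // same object as TestHive.sql(...)
    assert(getConf("spark.sql.example.key", "0") == "42") // same object as TestHive.getConf(...)
  }
}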
- assert(ctx.sql("SET").collect().size == 0) + assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(ctx.sql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } - assert(ctx.hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal))(collectResults(ctx.sql("SET"))) + assert(hiveconf.get(testKey, "") == testVal) + assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) - ctx.sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(ctx.hiveconf.get(testKey + testKey, "") == testVal + testVal) + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(ctx.sql("SET")) + collectResults(sql("SET")) } // "SET key" assertResult(Set(testKey -> testVal)) { - collectResults(ctx.sql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(ctx.sql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } - ctx.conf.clear() + conf.clear() } createQueryTest("select from thrift based table", @@ -1143,4 +1139,4 @@ class HiveQuerySuite extends HiveComparisonTest { } // for SPARK-2180 test -private case class HavingRow(key: Int, value: String, attr: Int) +case class HavingRow(key: Int, value: String, attr: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index fa69c3b84c02..b08db6de2d2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) @@ -27,22 +29,21 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) * included in the hive distribution. 
*/ class HiveResolutionSuite extends HiveComparisonTest { - import testImplicits._ test("SPARK-3698: case insensitive test for nested data") { - ctx.read.json(ctx.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed - ctx.sql("SELECT a[0].A.A from nested").queryExecution.analyzed + sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - ctx.read.json(ctx.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error val exception = intercept[AnalysisException] { - ctx.sql("SELECT a[0].b from nested").queryExecution.analyzed + sql("SELECT a[0].b from nested").queryExecution.analyzed } assert(exception.getMessage.contains("Ambiguous reference to fields")) } @@ -76,10 +77,10 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") - val query = ctx.sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), "The output schema did not preserve the case of the query.") query.collect() @@ -87,16 +88,16 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") - ctx.sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() + sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - ctx.sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") - assert(ctx.sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } createQueryTest("test ambiguousReferences resolved as hive", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 2d784f8062a0..5586a793618b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHiveContext +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive /** * A set of tests that validates support for Hive SerDe. 
*/ -class HiveSerDeSuite extends HiveComparisonTest { - import org.apache.hadoop.hive.serde2.RegexSerDe - +class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { override def beforeAll(): Unit = { - super.beforeAll() - ctx.cacheTables = false - ctx.sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + import TestHive._ + import org.apache.hadoop.hive.serde2.RegexSerDe + super.beforeAll() + TestHive.cacheTables = false + sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") """.stripMargin) - val dataFile = TestHiveContext.getHiveFile("data/files/sales.txt") - ctx.sql(s"LOAD DATA LOCAL INPATH '$dataFile' INTO TABLE sales") + sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } // table sales is not a cache table, and will be clear after reset diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 5e8f9f961b89..2209fc2f30a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,10 +18,14 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { - import testImplicits._ createQueryTest("partition_based_table_scan_with_different_serde", """ @@ -52,15 +56,15 @@ class HiveTableScanSuite extends HiveComparisonTest { """.stripMargin) test("Spark-4041: lowercase issue") { - ctx.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") - ctx.sql("insert into table tb select key, value from src") - ctx.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() - ctx.sql("drop table tb") + TestHive.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") + TestHive.sql("insert into table tb select key, value from src") + TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() + TestHive.sql("drop table tb") } test("Spark-4077: timestamp query for null value") { - ctx.sql("DROP TABLE IF EXISTS timestamp_query_null") - ctx.sql( + TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") + TestHive.sql( """ CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) ROW FORMAT DELIMITED @@ -70,20 +74,20 @@ class HiveTableScanSuite extends HiveComparisonTest { val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - ctx.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(ctx.sql("SELECT time from timestamp_query_null limit 2").collect() + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) - ctx.sql("DROP TABLE timestamp_query_null") + TestHive.sql("DROP TABLE timestamp_query_null") } test("Spark-4959 Attributes are case sensitive when using a select query from a projection") { - ctx.sql("create table spark_4959 (col1 
string)") - ctx.sql("""insert into table spark_4959 select "hi" from src limit 1""") - ctx.table("spark_4959").select( + sql("create table spark_4959 (col1 string)") + sql("""insert into table spark_4959 select "hi" from src limit 1""") + table("spark_4959").select( 'col1.as("CaseSensitiveColName"), 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(ctx.sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) - assert(ctx.sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 4ce21258d875..197e9bfb02c4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.hive.test.TestHive /** * A set of tests that validate type promotion and coercion rules. @@ -42,7 +43,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = ctx.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 8fbeb3498fef..10f2902e5eef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,8 +21,6 @@ import java.io.{DataInput, DataOutput} import java.util import java.util.Properties -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject @@ -30,11 +28,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable - import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive + import org.apache.spark.util.Utils +import scala.collection.JavaConversions._ + case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. @@ -46,12 +46,14 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. 
*/ -class HiveUDFSuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class HiveUDFSuite extends QueryTest { + + import TestHive.{udf, sql} + import TestHive.implicits._ test("spark sql udf test that returns a struct") { - ctx.udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) - assert(ctx.sql( + udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(sql( """ |SELECT getStruct(1).f1, | getStruct(1).f2, @@ -63,13 +65,13 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( - ctx.sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), + sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), Row(8) ) } test("hive struct udf") { - ctx.sql( + sql( """ |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT @@ -81,25 +83,25 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { stripMargin.format(classOf[PairSerDe].getName)) val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile - ctx.sql(s""" + sql(s""" ALTER TABLE hiveUDFTestTable ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") - ctx.sql("SELECT testUDF(pair) FROM hiveUDFTestTable") - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("Max/Min on named_struct") { def testOrderInStruct(): Unit = { - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT max(named_struct( | "key", key, | "value", value)).value FROM src """.stripMargin), Seq(Row("val_498"))) - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT min(named_struct( | "key", key, @@ -107,7 +109,7 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { """.stripMargin), Seq(Row("val_0"))) // nested struct cases - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT max(named_struct( | "key", named_struct( @@ -115,7 +117,7 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { "value", value), | "value", value)).value FROM src """.stripMargin), Seq(Row("val_498"))) - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT min(named_struct( | "key", named_struct( @@ -124,178 +126,176 @@ class HiveUDFSuite extends QueryTest with SharedHiveContext { | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } - val codegenDefault = ctx.getConf(SQLConf.CODEGEN_ENABLED) - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) + val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) + TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) testOrderInStruct() - ctx.setConf(SQLConf.CODEGEN_ENABLED, false) + TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) testOrderInStruct() - ctx.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { - ctx.sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") checkAnswer( - ctx.sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - ctx.reset() + sql("DROP TEMPORARY FUNCTION IF EXISTS 
test_avg") + TestHive.reset() } test("SPARK-2693 udaf aggregates test") { - checkAnswer(ctx.sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - ctx.sql("SELECT max(key) FROM src").collect().toSeq) + checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) - checkAnswer(ctx.sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - ctx.sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } test("Generic UDAF aggregates") { - checkAnswer(ctx.sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), - ctx.sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - checkAnswer(ctx.sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), - ctx.sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } test("UDFIntegerToString") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") val udfName = classOf[UDFIntegerToString].getName - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - ctx.sql("SELECT testUDFIntegerToString(i) FROM integerTable"), + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - ctx.reset() + TestHive.reset() } test("UDFToListString") { - val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - ctx.sql( - s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") val errMsg = intercept[AnalysisException] { - ctx.sql("SELECT testUDFToListString(s) FROM inputTable") + sql("SELECT testUDFToListString(s) FROM inputTable") } assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - ctx.reset() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") + TestHive.reset() } test("UDFToListInt") { - val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") val errMsg = intercept[AnalysisException] { - ctx.sql("SELECT testUDFToListInt(s) FROM inputTable") + sql("SELECT testUDFToListInt(s) FROM inputTable") } assert(errMsg.getMessage contains "List 
type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - ctx.reset() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") + TestHive.reset() } test("UDFToStringIntMap") { - val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + + sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + s"AS '${classOf[UDFToStringIntMap].getName}'") val errMsg = intercept[AnalysisException] { - ctx.sql("SELECT testUDFToStringIntMap(s) FROM inputTable") + sql("SELECT testUDFToStringIntMap(s) FROM inputTable") } assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - ctx.reset() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") + TestHive.reset() } test("UDFToIntIntMap") { - val testData = ctx.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + + sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + s"AS '${classOf[UDFToIntIntMap].getName}'") val errMsg = intercept[AnalysisException] { - ctx.sql("SELECT testUDFToIntIntMap(s) FROM inputTable") + sql("SELECT testUDFToIntIntMap(s) FROM inputTable") } assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - ctx.reset() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") + TestHive.reset() } test("UDFListListInt") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("listListIntTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - ctx.sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - ctx.reset() + TestHive.reset() } test("UDFListString") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - ctx.sql("SELECT testUDFListString(l) FROM listStringTable"), + sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") 
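// A minimal sketch of the temporary-function lifecycle these UDF tests follow, assuming the
// TestHive singleton; the function name demo_avg is hypothetical, the UDAF class is the same
// one the SPARK-6409 test above registers.
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive.sql

object TemporaryFunctionLifecycle {
  def run(): Unit = {
    // Register a real Hive UDAF under a temporary name.
    sql(s"CREATE TEMPORARY FUNCTION demo_avg AS '${classOf[GenericUDAFAverage].getName}'")
    // Exercise it against the src table that TestHive loads on demand.
    sql("SELECT demo_avg(key) FROM src").collect()
    // Drop the function and reset TestHive so later suites start from a clean metastore.
    sql("DROP TEMPORARY FUNCTION IF EXISTS demo_avg")
    TestHive.reset()
  }
}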
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - ctx.reset() + TestHive.reset() } test("UDFStringString") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - ctx.sql( - s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - ctx.sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - ctx.reset() + TestHive.reset() } test("UDFTwoListList") { - val testData = ctx.sparkContext.parallelize( + val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("TwoListTable") - ctx.sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - ctx.sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) - ctx.sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - ctx.reset() + TestHive.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34ea48a18310..3bf8f3ac2048 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,22 +17,23 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.hive.test.TestHive + /* Implicit conversions */ import scala.collection.JavaConversions._ /** * A set of test cases that validate partition and column pruning. */ -class PruningSuite extends HiveComparisonTest { - - protected override def beforeAll(): Unit = { - super.beforeAll() - ctx.cacheTables = false - // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need - // to reset the environment to ensure all referenced tables in this suites are not cached - // in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. - ctx.reset() - } +class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + TestHive.cacheTables = false + + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset + // the environment to ensure all referenced tables in this suites are not cached in-memory. + // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. 
+ TestHive.reset() // Column pruning tests @@ -144,8 +145,7 @@ class PruningSuite extends HiveComparisonTest { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val _ctx = ctx - val plan = new _ctx.QueryExecution(sql).executedPlan + val plan = new TestHive.QueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75eb0c4c8016..79a136ae6f61 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,10 +26,12 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.{SharedHiveContext, TestHiveContext} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.test.SQLTestData.TestData +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -63,32 +65,32 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. 
*/ -class SQLQuerySuite extends QueryTest with SharedHiveContext { - import testImplicits._ +class SQLQuerySuite extends QueryTest with SQLTestUtils { + override def sqlContext: SQLContext = TestHive test("UDTF") { - ctx.sql(s"ADD JAR ${TestHiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - ctx.sql( + sql( """ |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) checkAnswer( - ctx.sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), Row(97, 500) :: Row(97, 500) :: Nil) checkAnswer( - ctx.sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), Row(3) :: Row(3) :: Nil) } test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - val query = ctx.sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) } @@ -113,7 +115,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { orders.toDF.registerTempTable("orders1") orderUpdates.toDF.registerTempTable("orderupdates1") - ctx.sql( + sql( """CREATE TABLE orders( | id INT, | make String, @@ -126,7 +128,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { |STORED AS PARQUET """.stripMargin) - ctx.sql( + sql( """CREATE TABLE orderupdates( | id INT, | make String, @@ -139,12 +141,12 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { |STORED AS PARQUET """.stripMargin) - ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") - ctx.sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") - ctx.sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") checkAnswer( - ctx.sql( + sql( """ |select orders.state, orders.month |from orders @@ -162,22 +164,22 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val allFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted - checkAnswer(ctx.sql("SHOW functions"), allFunctions.map(Row(_))) - checkAnswer(ctx.sql("SHOW functions abs"), Row("abs")) - checkAnswer(ctx.sql("SHOW functions 'abs'"), Row("abs")) - checkAnswer(ctx.sql("SHOW functions abc.abs"), Row("abs")) - checkAnswer(ctx.sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(ctx.sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(ctx.sql("SHOW functions `~`"), Row("~")) - checkAnswer(ctx.sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(ctx.sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(sql("SHOW functions abs"), Row("abs")) + checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) + 
checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) + checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(ctx.sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) } test("describe functions") { // The Spark SQL built-in functions - checkExistence(ctx.sql("describe function extended upper"), true, + checkExistence(sql("describe function extended upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase", @@ -185,18 +187,18 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { "> SELECT upper('SparkSql')", "'SPARKSQL'") - checkExistence(ctx.sql("describe functioN Upper"), true, + checkExistence(sql("describe functioN Upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase") - checkExistence(ctx.sql("describe functioN Upper"), false, + checkExistence(sql("describe functioN Upper"), false, "Extended Usage") - checkExistence(ctx.sql("describe functioN abcadf"), true, + checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf is not found.") - checkExistence(ctx.sql("describe functioN `~`"), true, + checkExistence(sql("describe functioN `~`"), true, "Function: ~", "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", "Usage: ~ n - Bitwise not") @@ -206,7 +208,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - val query = ctx.sql( + val query = sql( """ |SELECT | MIN(c1), @@ -230,7 +232,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") - ctx.sql( + sql( """ |CREATE TABLE with_table1 AS |WITH T AS ( @@ -240,27 +242,27 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { |SELECT * |FROM T """.stripMargin) - val query = ctx.sql("SELECT * FROM with_table1") + val query = sql("SELECT * FROM with_table1") checkAnswer(query, Row(1, 1) :: Nil) } test("explode nested Field") { Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( - ctx.sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), + sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) } test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( - ctx.sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), - ctx.sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq + sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), + sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq ) } test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(ctx.catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { case LogicalRelation(r: 
ParquetRelation) => if (!isDataSourceParquet) { @@ -278,90 +280,89 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { } } - val originalConf = ctx.convertCTAS + val originalConf = convertCTAS - ctx.setConf(HiveContext.CONVERT_CTAS, true) + setConf(HiveContext.CONVERT_CTAS, true) try { - ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - ctx.sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") var message = intercept[AnalysisException] { - ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage assert(message.contains("ctas1 already exists")) checkRelation("ctas1", true) - ctx.sql("DROP TABLE ctas1") + sql("DROP TABLE ctas1") // Specifying database name for query can be converted to data source write path // is not allowed right now. message = intercept[AnalysisException] { - ctx.sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage assert( message.contains("Cannot specify database name in a CTAS statement"), "When spark.sql.hive.convertCTAS is true, we should not allow " + "database name specified.") - ctx.sql("CREATE TABLE ctas1 stored as textfile" + + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", true) - ctx.sql("DROP TABLE ctas1") + sql("DROP TABLE ctas1") - ctx.sql("CREATE TABLE ctas1 stored as sequencefile" + + sql("CREATE TABLE ctas1 stored as sequencefile" + " AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", true) - ctx.sql("DROP TABLE ctas1") + sql("DROP TABLE ctas1") - ctx.sql( - "CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - ctx.sql("DROP TABLE ctas1") + sql("DROP TABLE ctas1") - ctx.sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - ctx.sql("DROP TABLE ctas1") - ctx.sql( - "CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") checkRelation("ctas1", false) - ctx.sql("DROP TABLE ctas1") + sql("DROP TABLE ctas1") } finally { - ctx.setConf(HiveContext.CONVERT_CTAS, originalConf) - ctx.sql("DROP TABLE IF EXISTS ctas1") + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") } } test("SQL Dialect Switching") { - assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) - ctx.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(ctx.getSQLDialect().getClass === classOf[MyDialect]) - assert(ctx.sql("SELECT 1").collect() === Array(Row(1))) + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(getSQLDialect().getClass === classOf[MyDialect]) + assert(sql("SELECT 
1").collect() === Array(Row(1))) // set the dialect back to the DefaultSQLDialect - ctx.sql("SET spark.sql.dialect=sql") - assert(ctx.getSQLDialect().getClass === classOf[DefaultParserDialect]) - ctx.sql("SET spark.sql.dialect=hiveql") - assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) + sql("SET spark.sql.dialect=sql") + assert(getSQLDialect().getClass === classOf[DefaultParserDialect]) + sql("SET spark.sql.dialect=hiveql") + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) // set invalid dialect - ctx.sql("SET spark.sql.dialect.abc=MyTestClass") - ctx.sql("SET spark.sql.dialect=abc") + sql("SET spark.sql.dialect.abc=MyTestClass") + sql("SET spark.sql.dialect=abc") intercept[Exception] { - ctx.sql("SELECT 1") + sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - ctx.getSQLDialect().getClass === classOf[HiveQLDialect] + getSQLDialect().getClass === classOf[HiveQLDialect] - ctx.sql("SET spark.sql.dialect=MyTestClass") + sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { - ctx.sql("SELECT 1") + sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - assert(ctx.getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) } test("CTAS with serde") { - ctx.sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() - ctx.sql( + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() + sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") @@ -371,7 +372,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { | SELECT key, value | FROM src | ORDER BY key, value""".stripMargin).collect() - ctx.sql( + sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS @@ -380,41 +381,41 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { | ORDER BY key, value""".stripMargin).collect() // the table schema may like (key: integer, value: string) - ctx.sql( + sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() // do nothing cause the table ctas4 already existed. 
- ctx.sql( + sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() checkAnswer( - ctx.sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - ctx.sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT k, value FROM ctas1 ORDER BY k, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) checkAnswer( - ctx.sql("SELECT key, value FROM ctas2 ORDER BY key, value"), - ctx.sql( + sql("SELECT key, value FROM ctas2 ORDER BY key, value"), + sql( """ SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) checkAnswer( - ctx.sql("SELECT key, value FROM ctas3 ORDER BY key, value"), - ctx.sql( + sql("SELECT key, value FROM ctas3 ORDER BY key, value"), + sql( """ SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) intercept[AnalysisException] { - ctx.sql( + sql( """CREATE TABLE ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() } checkAnswer( - ctx.sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - ctx.sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) + sql("SELECT key, value FROM ctas4 ORDER BY key, value"), + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - checkExistence(ctx.sql("DESC EXTENDED ctas2"), true, + checkExistence(sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", @@ -422,7 +423,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - ctx.sql( + sql( """CREATE TABLE ctas5 | STORED AS parquet AS | SELECT key, value @@ -430,7 +431,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { | ORDER BY key, value""".stripMargin).collect() withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - checkExistence(ctx.sql("DESC EXTENDED ctas5"), true, + checkExistence(sql("DESC EXTENDED ctas5"), true, "name:key", "type:string", "name:value", "ctas5", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", @@ -442,57 +443,57 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { // use the Hive SerDe for parquet tables withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( - ctx.sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - ctx.sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) } } test("specifying the column list for CTAS") { Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") - ctx.sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") checkAnswer( - ctx.sql("SELECT a, b from gen__tmp"), - ctx.sql("select key, value from mytable1").collect()) - ctx.sql("DROP TABLE gen__tmp") + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + sql("DROP TABLE gen__tmp") - ctx.sql("create table gen__tmp(a double, b double) as select key, value from mytable1") + sql("create table gen__tmp(a double, b double) as select key, value from mytable1") checkAnswer( - ctx.sql("SELECT a, b from gen__tmp"), - ctx.sql("select 
cast(key as double), cast(value as double) from mytable1").collect()) - ctx.sql("DROP TABLE gen__tmp") + sql("SELECT a, b from gen__tmp"), + sql("select cast(key as double), cast(value as double) from mytable1").collect()) + sql("DROP TABLE gen__tmp") - ctx.sql("drop table mytable1") + sql("drop table mytable1") } test("command substitution") { - ctx.sql("set tbl=src") + sql("set tbl=src") checkAnswer( - ctx.sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), - ctx.sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) - ctx.sql("set hive.variable.substitute=false") // disable the substitution - ctx.sql("set tbl2=src") + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") intercept[Exception] { - ctx.sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() } - ctx.sql("set hive.variable.substitute=true") // enable the substitution + sql("set hive.variable.substitute=true") // enable the substitution checkAnswer( - ctx.sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), - ctx.sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) } test("ordering not in select") { checkAnswer( - ctx.sql("SELECT key FROM src ORDER BY value"), - ctx.sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) + sql("SELECT key FROM src ORDER BY value"), + sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) } test("ordering not in agg") { checkAnswer( - ctx.sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), - ctx.sql(""" + sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), + sql(""" SELECT key FROM ( SELECT key, value @@ -502,103 +503,103 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { } test("double nested data") { - ctx.sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) .toDF().registerTempTable("nested") checkAnswer( - ctx.sql("SELECT f1.f2.f3 FROM nested"), + sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(ctx.sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), + checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), Seq.empty[Row]) checkAnswer( - ctx.sql("SELECT * FROM test_ctas_1234"), - ctx.sql("SELECT * FROM nested").collect().toSeq) + sql("SELECT * FROM test_ctas_1234"), + sql("SELECT * FROM nested").collect().toSeq) intercept[AnalysisException] { - ctx.sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(ctx.sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) checkAnswer( - ctx.sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), - ctx.sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } test("SPARK-4825 save 
join to table") { - val testData = ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() - ctx.sql("CREATE TABLE test1 (key INT, value STRING)") + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() + sql("CREATE TABLE test1 (key INT, value STRING)") testData.write.mode(SaveMode.Append).insertInto("test1") - ctx.sql("CREATE TABLE test2 (key INT, value STRING)") + sql("CREATE TABLE test2 (key INT, value STRING)") testData.write.mode(SaveMode.Append).insertInto("test2") testData.write.mode(SaveMode.Append).insertInto("test2") - ctx.sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") + sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( - ctx.table("test"), - ctx.sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) + table("test"), + sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) } test("SPARK-3708 Backticks aren't handled correctly is aliases") { checkAnswer( - ctx.sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), - ctx.sql("SELECT `key` FROM src").collect().toSeq) + sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), + sql("SELECT `key` FROM src").collect().toSeq) } test("SPARK-3834 Backticks not correctly handled in subquery aliases") { checkAnswer( - ctx.sql("SELECT a.key FROM (SELECT key FROM src) `a`"), - ctx.sql("SELECT `key` FROM src").collect().toSeq) + sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + sql("SELECT `key` FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise & operator") { checkAnswer( - ctx.sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), - ctx.sql("SELECT 1 FROM src").collect().toSeq) + sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise | operator") { checkAnswer( - ctx.sql("SELECT case when 1|0=1 then 1 else 0 end FROM src"), - ctx.sql("SELECT 1 FROM src").collect().toSeq) + sql("SELECT case when 1|0=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise ^ operator") { checkAnswer( - ctx.sql("SELECT case when 1^0=1 then 1 else 0 end FROM src"), - ctx.sql("SELECT 1 FROM src").collect().toSeq) + sql("SELECT case when 1^0=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-3814 Support Bitwise ~ operator") { checkAnswer( - ctx.sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), - ctx.sql("SELECT 1 FROM src").collect().toSeq) + sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(ctx.sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), - ctx.sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) } test("SPARK-2554 SumDistinct partial aggregation") { - checkAnswer(ctx.sql("SELECT sum( distinct key) FROM src group by key order by key"), - ctx.sql("SELECT distinct key FROM src order by key").collect().toSeq) + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), + sql("SELECT distinct key FROM src order by key").collect().toSeq) } 
test("SPARK-4963 DataFrame sample on mutable row return wrong result") { - ctx.sql("SELECT * FROM src WHERE key % 2 = 0") + sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) .registerTempTable("sampled") (1 to 10).foreach { i => checkAnswer( - ctx.sql("SELECT * FROM sampled WHERE key % 2 = 1"), + sql("SELECT * FROM sampled WHERE key % 2 = 1"), Seq.empty[Row]) } } test("SPARK-4699 HiveContext should be case insensitive by default") { checkAnswer( - ctx.sql("SELECT KEY FROM Src ORDER BY value"), - ctx.sql("SELECT key FROM src ORDER BY value").collect().toSeq) + sql("SELECT KEY FROM Src ORDER BY value"), + sql("SELECT key FROM src ORDER BY value").collect().toSeq) } test("SPARK-5284 Insert into Hive throws NPE when a inner complex type field has a null value") { @@ -610,76 +611,74 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { StructField("innerMap", MapType(StringType, IntegerType)) :: Nil), true) :: Nil) val row = Row(Row(null, null, null)) - val rowRdd = ctx.sparkContext.parallelize(row :: Nil) + val rowRdd = sparkContext.parallelize(row :: Nil) - ctx.createDataFrame(rowRdd, schema).registerTempTable("testTable") + TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") - ctx.sql( + sql( """CREATE TABLE nullValuesInInnerComplexTypes | (s struct, | innerArray:array, | innerMap: map>) """.stripMargin).collect() - ctx.sql( + sql( """ |INSERT OVERWRITE TABLE nullValuesInInnerComplexTypes |SELECT * FROM testTable """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM nullValuesInInnerComplexTypes"), + sql("SELECT * FROM nullValuesInInnerComplexTypes"), Row(Row(null, null, null)) ) - ctx.sql("DROP TABLE nullValuesInInnerComplexTypes") - ctx.dropTempTable("testTable") + sql("DROP TABLE nullValuesInInnerComplexTypes") + dropTempTable("testTable") } test("SPARK-4296 Grouping field with Hive UDF as sub expression") { - val rdd = ctx.sparkContext.makeRDD( - """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - ctx.read.json(rdd).registerTempTable("data") + val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + read.json(rdd).registerTempTable("data") checkAnswer( - ctx.sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) - ctx.dropTempTable("data") + dropTempTable("data") - ctx.read.json(rdd).registerTempTable("data") - checkAnswer(ctx.sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + read.json(rdd).registerTempTable("data") + checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) - ctx.dropTempTable("data") + dropTempTable("data") } test("resolve udtf in projection #1") { - val rdd = ctx.sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - ctx.read.json(rdd).registerTempTable("data") - val df = ctx.sql("SELECT explode(a) AS val FROM data") + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } test("resolve udtf in projection #2") { - val rdd = ctx.sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - ctx.read.json(rdd).registerTempTable("data") - checkAnswer(ctx.sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) - checkAnswer(ctx.sql( - "SELECT explode(map(1, 1)) as (k1, k2) FROM data 
LIMIT 1"), Row(1, 1) :: Nil) + val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) + checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { - ctx.sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") + sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") } intercept[AnalysisException] { - ctx.sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") + sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") } } // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { - val rdd = ctx.sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - ctx.read.json(rdd).registerTempTable("data") + val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) + read.json(rdd).registerTempTable("data") checkAnswer( - ctx.sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), + sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) } @@ -690,40 +689,40 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = ctx.sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - ctx.read.json(rdd).registerTempTable("data") - val originalConf = ctx.convertCTAS - ctx.setConf(HiveContext.CONVERT_CTAS, false) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + val originalConf = convertCTAS + setConf(HiveContext.CONVERT_CTAS, false) try { - ctx.sql("CREATE TABLE explodeTest (key bigInt)") - ctx.table("explodeTest").queryExecution.analyzed match { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { case metastoreRelation: MetastoreRelation => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } - ctx.sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") checkAnswer( - ctx.sql("SELECT key from explodeTest"), + sql("SELECT key from explodeTest"), (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) ) - ctx.sql("DROP TABLE explodeTest") - ctx.dropTempTable("data") + sql("DROP TABLE explodeTest") + dropTempTable("data") } finally { - ctx.setConf(HiveContext.CONVERT_CTAS, originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } } test("sanity test for SPARK-6618") { (1 to 100).par.map { i => val tableName = s"SPARK_6618_table_$i" - ctx.sql(s"CREATE TABLE $tableName (col1 string)") - ctx.catalog.lookupRelation(Seq(tableName)) - ctx.table(tableName) - ctx.tables() - ctx.sql(s"DROP TABLE $tableName") + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") } } @@ -733,7 +732,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") - ctx.sql("select d from dn union all select d * 2 from dn") + sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } @@ 
-741,7 +740,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === - ctx.sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") .queryExecution.toRdd.count()) } @@ -749,7 +748,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === - ctx.sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") .queryExecution.toRdd.count()) } @@ -757,7 +756,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val data = (1 to 5).map { i => (i, i) } data.toDF("key", "value").registerTempTable("test") checkAnswer( - ctx.sql("""FROM + sql("""FROM |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t |SELECT thing1 + 1 """.stripMargin), (2 to 6).map(i => Row(i))) @@ -772,10 +771,10 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - ctx.sql( + sql( """ |select area, sum(product), sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -790,7 +789,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - ctx.sql( + sql( """ |select area, sum(product) - 1, sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -805,7 +804,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - ctx.sql( + sql( """ |select area, sum(product), sum(product) / sum(sum(product)) over (partition by area) |from windowData group by month, area @@ -820,7 +819,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { ).map(i => Row(i._1, i._2, i._3))) checkAnswer( - ctx.sql( + sql( """ |select area, sum(product), sum(product) / sum(sum(product) - 1) over (partition by area) |from windowData group by month, area @@ -844,10 +843,10 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - ctx.sql( + sql( """ |select month, area, product, sum(product + 1) over (partition by 1 order by 2) |from windowData @@ -862,7 +861,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { ).map(i => Row(i._1, i._2, i._3, i._4))) checkAnswer( - ctx.sql( + sql( """ |select month, area, product, sum(product) |over (partition by month % 2 order by 10 - product) @@ -887,10 +886,10 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { WindowData(5, "c", 9), WindowData(6, "c", 10) ) - ctx.sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( - ctx.sql( + sql( """ |select month, area, month % 2, |lag(product, 1 + 1, product) over (partition by month % 2 order by area) @@ -907,7 +906,7 @@ class 
SQLQuerySuite extends QueryTest with SharedHiveContext { } test("window function: multiple window expressions in a single expression") { - val nums = ctx.sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") val expected = @@ -922,7 +921,7 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { Row(1, 9, 45, 55, 25, 125) :: Row(0, 10, 55, 55, 30, 140) :: Nil - val actual = ctx.sql( + val actual = sql( """ |SELECT | y, @@ -939,20 +938,20 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { checkAnswer(actual, expected) - ctx.dropTempTable("nums") + dropTempTable("nums") } test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( - ctx.sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), + sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) } test("SPARK-7595: Window will cause resolve failed with self join") { - ctx.sql("SELECT * FROM src") // Force loading of src table. + sql("SELECT * FROM src") // Force loading of src table. - checkAnswer(ctx.sql( + checkAnswer(sql( """ |with | v1 as (select key, count(value) over (partition by key) cnt_val from src), @@ -965,27 +964,27 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { Seq(1, 2, 3).map { i => (i.toString, i.toString) }.toDF("key", "value").registerTempTable("df_analysis") - ctx.sql("SELECT kEy from df_analysis group by key").collect() - ctx.sql("SELECT kEy+3 from df_analysis group by key+3").collect() - ctx.sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() - ctx.sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() - ctx.sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() - ctx.sql("SELECT 2 from df_analysis A group by key+1").collect() + sql("SELECT kEy from df_analysis group by key").collect() + sql("SELECT kEy+3 from df_analysis group by key+3").collect() + sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() + sql("SELECT 2 from df_analysis A group by key+1").collect() intercept[AnalysisException] { - ctx.sql("SELECT kEy+1 from df_analysis group by key+3") + sql("SELECT kEy+1 from df_analysis group by key+3") } intercept[AnalysisException] { - ctx.sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") + sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") } } test("Cast STRING to BIGINT") { - checkAnswer(ctx.sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) + checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) } // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest test("udf_java_method") { - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT java_method("java.lang.String", "valueOf", 1), | java_method("java.lang.String", "isEmpty"), @@ -1008,34 +1007,34 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { test("dynamic partition value test") { try { - ctx.sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("set 
hive.exec.dynamic.partition.mode=nonstrict") // date - ctx.sql("drop table if exists dynparttest1") - ctx.sql("create table dynparttest1 (value int) partitioned by (pdate date)") - ctx.sql( + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( """ |insert into table dynparttest1 partition(pdate) | select count(*), cast('2015-05-21' as date) as pdate from src """.stripMargin) checkAnswer( - ctx.sql("select * from dynparttest1"), + sql("select * from dynparttest1"), Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) // decimal - ctx.sql("drop table if exists dynparttest2") - ctx.sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") - ctx.sql( + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( """ |insert into table dynparttest2 partition(pdec) | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src """.stripMargin) checkAnswer( - ctx.sql("select * from dynparttest2"), + sql("select * from dynparttest2"), Seq(Row(500, new java.math.BigDecimal("100.1")))) } finally { - ctx.sql("drop table if exists dynparttest1") - ctx.sql("drop table if exists dynparttest2") - ctx.sql("set hive.exec.dynamic.partition.mode=strict") + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") } } @@ -1044,10 +1043,10 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. - val jar = TestHiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - ctx.sql(s"ADD JAR $jar") + TestHive.sql( + s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - ctx.sql( + TestHive.sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -1069,14 +1068,14 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { checkAnswer( - ctx.sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), Row(false)) } test("SPARK-6785: HiveQuerySuite - Date cast") { // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST checkAnswer( - ctx.sql( + sql( """ | SELECT | CAST(CAST(0 AS timestamp) AS date), @@ -1096,44 +1095,45 @@ class SQLQuerySuite extends QueryTest with SharedHiveContext { } test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { - val df = ctx.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + val df = + TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( - ctx.sql( + TestHive.sql( """ |select id, concat(year(datef)) |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) - ctx.dropTempTable("test_SPARK8588") + TestHive.dropTempTable("test_SPARK8588") } test("SPARK-9371: fix the support for special chars in column names for hive context") { - ctx.read.json(ctx.sparkContext.makeRDD( + TestHive.read.json(TestHive.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": 
{"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") - checkAnswer(ctx.sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } test("Convert hive interval term into Literal of CalendarIntervalType") { - checkAnswer(ctx.sql("select interval '10-9' year to month"), + checkAnswer(sql("select interval '10-9' year to month"), Row(CalendarInterval.fromString("interval 10 years 9 months"))) - checkAnswer(ctx.sql("select interval '20 15:40:32.99899999' day to second"), + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + "32 seconds 99 milliseconds 899 microseconds"))) - checkAnswer(ctx.sql("select interval '30' year"), + checkAnswer(sql("select interval '30' year"), Row(CalendarInterval.fromString("interval 30 years"))) - checkAnswer(ctx.sql("select interval '25' month"), + checkAnswer(sql("select interval '25' month"), Row(CalendarInterval.fromString("interval 25 months"))) - checkAnswer(ctx.sql("select interval '-100' day"), + checkAnswer(sql("select interval '-100' day"), Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) - checkAnswer(ctx.sql("select interval '40' hour"), + checkAnswer(sql("select interval '40' hour"), Row(CalendarInterval.fromString("interval 1 days 16 hours"))) - checkAnswer(ctx.sql("select interval '80' minute"), + checkAnswer(sql("select interval '80' minute"), Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) - checkAnswer(ctx.sql("select interval '299.889987299' second"), + checkAnswer(sql("select interval '299.889987299' second"), Row(CalendarInterval.fromString( "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 53678ce73302..0875232aede3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -22,13 +22,16 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { +class ScriptTransformationSuite extends SparkPlanTest { + + override def sqlContext: SQLContext = TestHive private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -55,7 +58,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = noSerdeIOSchema - )(ctx), + )(TestHive), rowsDf.collect()) } @@ -69,7 +72,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = serdeIOSchema - )(ctx), + )(TestHive), rowsDf.collect()) } @@ -84,7 +87,7 @@ 
class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = noSerdeIOSchema - )(ctx), + )(TestHive), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) @@ -101,7 +104,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SharedHiveContext { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = serdeIOSchema - )(ctx), + )(TestHive), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 7da1242a283f..deec0048d24b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { - import testImplicits._ - override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + import sqlContext._ + import sqlContext.implicits._ + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -36,7 +37,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - ctx.sparkContext + sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write @@ -47,7 +48,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - ctx.read.options(Map( + read.options(Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index dfbbc21539a6..a46ca9a2c970 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,15 +18,19 @@ package org.apache.spark.sql.hive.orc import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.hive.conf.HiveConf.ConfVars - -import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.SharedHiveContext - // The data where the partitioning key exists only in the directory structure. 
case class OrcParData(intField: Int, stringField: String) @@ -34,23 +38,27 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { - import testImplicits._ - +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) } + def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { df.write.mode("overwrite").orc(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally ctx.dropTempTable(tableName) + try f finally TestHive.dropTempTable(tableName) } protected def makePartitionDir( @@ -81,11 +89,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - ctx.read.orc(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -93,7 +101,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, i.toString, pi, ps)) checkAnswer( - ctx.sql("SELECT intField, pi FROM t"), + sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -101,14 +109,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi = 1"), + sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, i.toString, 1, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), + sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -128,11 +136,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - ctx.read.orc(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -140,7 +148,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi, i.toString, ps)) checkAnswer( - ctx.sql("SELECT intField, pi FROM t"), + sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -148,14 +156,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi = 1"), + sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, 1, i.toString, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), + sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -177,14 +185,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> 
ps)) } - ctx.read + read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -192,14 +200,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, i.toString, pi, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi IS NULL"), + sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) } yield Row(i, i.toString, null, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps IS NULL"), + sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -219,14 +227,14 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - ctx.read + read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -234,7 +242,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with SharedHiveContext { } yield Row(i, pi, i.toString, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps IS NULL"), + sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 31dd2ef96d52..744d46293814 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,8 +21,11 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -45,8 +48,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with OrcTest { - import testImplicits._ +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -61,14 +63,14 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - ctx.read.orc(file), + sqlContext.read.orc(file), data.toDF().collect()) } } test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => - val bytes = ctx.read.orc(file).head().getAs[Array[Byte]](0) + val bytes = read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, "utf8") === "test") } } @@ -86,16 +88,16 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - ctx.read.orc(file), + read.orc(file), data.toDF().collect()) } } test("Creating case class RDD table") { val data = (1 to 100).map(i => (i, s"val_$i")) - ctx.sparkContext.parallelize(data).toDF().registerTempTable("t") + sparkContext.parallelize(data).toDF().registerTempTable("t") withTempTable("t") { - 
checkAnswer(ctx.sql("SELECT * FROM t"), data.toDF().collect()) + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) } } @@ -108,13 +110,13 @@ class OrcQuerySuite extends QueryTest with OrcTest { // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = leaf-0 - assert(ctx.sql("SELECT name FROM t WHERE age <= 5").count() === 5) + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = (not leaf-0) assertResult(10) { - ctx.sql("SELECT name, contacts FROM t where age > 5") + sql("SELECT name, contacts FROM t where age > 5") .flatMap(_.getAs[Seq[_]]("contacts")) .count() } @@ -124,7 +126,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // leaf-1 = (LESS_THAN age 8) // expr = (and (not leaf-0) leaf-1) { - val df = ctx.sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") assert(df.count() === 2) assertResult(4) { df.flatMap(_.getAs[Seq[_]]("contacts")).count() @@ -136,7 +138,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // leaf-1 = (LESS_THAN_EQUALS age 8) // expr = (or leaf-0 (not leaf-1)) { - val df = ctx.sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") assert(df.count() === 3) assertResult(6) { df.flatMap(_.getAs[Seq[_]]("contacts")).count() @@ -156,7 +158,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { withOrcFile(data) { file => checkAnswer( - ctx.read.orc(file), + read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -173,7 +175,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { // Following codec is supported in hive-0.13.1, ignore it now ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { val data = (1 to 100).map(i => (i, s"val_$i")) - val conf = ctx.sparkContext.hadoopConfiguration + val conf = sparkContext.hadoopConfiguration conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") withOrcFile(data) { file => @@ -200,33 +202,33 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("simple select queries") { withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { checkAnswer( - ctx.sql("SELECT `_1` FROM t where t.`_1` > 5"), + sql("SELECT `_1` FROM t where t.`_1` > 5"), (6 until 10).map(Row.apply(_))) checkAnswer( - ctx.sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), (0 until 5).map(Row.apply(_))) } } test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withOrcTable(data, "t") { - ctx.sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withOrcTable(data, "t") { - ctx.sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + 
catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -237,7 +239,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { } withOrcTable(data, "t") { - val selfJoin = ctx.sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) @@ -252,7 +254,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("nested data - struct with array field") { val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withOrcTable(data, "t") { - checkAnswer(ctx.sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) }) } @@ -261,7 +263,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("nested data - array of struct") { val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withOrcTable(data, "t") { - checkAnswer(ctx.sql("SELECT `_1`[0].`_2` FROM t"), data.map { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) }) } @@ -269,18 +271,18 @@ class OrcQuerySuite extends QueryTest with OrcTest { test("columns only referenced by pushed down filters should remain") { withOrcTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(ctx.sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) } } test("SPARK-5309 strings stored using dictionary compression in orc") { withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { checkAnswer( - ctx.sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), (0 until 10).map(i => Row("same", "run_" + i, 100))) checkAnswer( - ctx.sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), List(Row("same", "run_5", 100))) } } @@ -291,7 +293,7 @@ class OrcQuerySuite extends QueryTest with OrcTest { withTable("empty_orc") { withTempTable("empty", "single") { - ctx.sql( + sqlContext.sql( s"""CREATE TABLE empty_orc(key INT, value STRING) |STORED AS ORC |LOCATION '$path' @@ -302,13 +304,13 @@ class OrcQuerySuite extends QueryTest with OrcTest { // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because // Spark SQL ORC data source always avoids write empty ORC files. 
- ctx.sql( + sqlContext.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM empty """.stripMargin) val errorMessage = intercept[AnalysisException] { - ctx.read.orc(path) + sqlContext.read.orc(path) }.getMessage assert(errorMessage.contains("Failed to discover schema from ORC files")) @@ -316,12 +318,12 @@ class OrcQuerySuite extends QueryTest with OrcTest { val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) singleRowDF.registerTempTable("single") - ctx.sql( + sqlContext.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM single """.stripMargin) - val df = ctx.read.orc(path) + val df = sqlContext.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 641aafb24884..82e08caf4645 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive.orc import java.io.File +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.SharedHiveContext case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with SharedHiveContext { - import testImplicits._ - +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { var orcTableDir: File = null var orcTableAsDir: File = null @@ -41,14 +41,15 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { orcTableDir = File.createTempFile("orctests", "sparksql") orcTableDir.delete() orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ - ctx.sparkContext + sparkContext .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() .registerTempTable(s"orc_temp_table") - ctx.sql( + sql( s"""CREATE EXTERNAL TABLE normal_orc( | intField INT, | stringField STRING @@ -57,81 +58,76 @@ abstract class OrcSuite extends QueryTest with SharedHiveContext { |LOCATION '${orcTableAsDir.getCanonicalPath}' """.stripMargin) - ctx.sql( + sql( s"""INSERT INTO TABLE normal_orc |SELECT intField, stringField FROM orc_temp_table """.stripMargin) } override def afterAll(): Unit = { - try { - orcTableDir.delete() - orcTableAsDir.delete() - } finally { - super.afterAll() - } + orcTableDir.delete() + orcTableAsDir.delete() } test("create temporary orc table") { - checkAnswer(ctx.sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) checkAnswer( - ctx.sql("SELECT * FROM normal_orc_source"), + sql("SELECT * FROM normal_orc_source"), (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - ctx.sql("SELECT * FROM normal_orc_source where intField > 5"), + sql("SELECT * FROM normal_orc_source where intField > 5"), (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - ctx.sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), (1 to 10).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { - checkAnswer(ctx.sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) checkAnswer( - ctx.sql("SELECT * FROM 
normal_orc_source"), + sql("SELECT * FROM normal_orc_source"), (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - ctx.sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - ctx.sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), (1 to 10).map(i => Row(1, s"part-$i"))) } test("appending insert") { - ctx.sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") checkAnswer( - ctx.sql("SELECT * FROM normal_orc_source"), + sql("SELECT * FROM normal_orc_source"), (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => Seq.fill(2)(Row(i, s"part-$i")) }) } test("overwrite insert") { - ctx.sql( + sql( """INSERT OVERWRITE TABLE normal_orc_as_source |SELECT * FROM orc_temp_table WHERE intField > 5 """.stripMargin) checkAnswer( - ctx.sql("SELECT * FROM normal_orc_as_source"), + sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } } class OrcSourceSuite extends OrcSuite { - override def beforeAll(): Unit = { super.beforeAll() - ctx.sql( + sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( @@ -139,7 +135,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - ctx.sql( + sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index d974011c4699..145965388da0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -24,10 +24,13 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SparkFunSuite with SharedHiveContext { - import testImplicits._ +private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => + lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive + + import sqlContext.implicits._ + import sqlContext.sparkContext /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` @@ -37,7 +40,7 @@ private[sql] trait OrcTest extends SparkFunSuite with SharedHiveContext { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - ctx.sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -49,7 +52,7 @@ private[sql] trait OrcTest extends SparkFunSuite with SharedHiveContext { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(ctx.read.orc(path))) + withOrcFile(data)(path => f(sqlContext.read.orc(path))) } /** @@ -61,7 +64,7 @@ private[sql] trait OrcTest extends SparkFunSuite with SharedHiveContext { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - ctx.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, 
tableName) withTempTable(tableName)(f) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 65c7387c12a1..50f02432dacc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.hive import java.io.File +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -53,7 +58,6 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -64,7 +68,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { "jt", "jt_array", "test_parquet") - ctx.sql(s""" + sql(s""" create external table partitioned_parquet ( intField INT, @@ -78,7 +82,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${partitionedTableDir.getCanonicalPath}' """) - ctx.sql(s""" + sql(s""" create external table partitioned_parquet_with_key ( intField INT, @@ -92,7 +96,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${partitionedTableDirWithKey.getCanonicalPath}' """) - ctx.sql(s""" + sql(s""" create external table normal_parquet ( intField INT, @@ -105,7 +109,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { location '${new File(normalTableDir, "normal").getCanonicalPath}' """) - ctx.sql(s""" + sql(s""" CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes ( intField INT, @@ -121,7 +125,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}' """) - ctx.sql(s""" + sql(s""" CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes ( intField INT, @@ -137,7 +141,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' """) - ctx.sql( + sql( """ |create table test_parquet |( @@ -151,67 +155,63 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) (1 to 10).foreach { p => - ctx.sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - ctx.sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") + sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - ctx.sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") + sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") } (1 to 10).foreach { p => - ctx.sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD 
PARTITION (p=$p)") + sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") } - val rdd1 = ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - ctx.read.json(rdd1).registerTempTable("jt") - val rdd2 = ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - ctx.read.json(rdd2).registerTempTable("jt_array") + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd1).registerTempTable("jt") + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) + read.json(rdd2).registerTempTable("jt_array") - ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) + setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } override def afterAll(): Unit = { - try { - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - ctx.setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) - } finally { - super.afterAll() - } + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { assert( - ctx.sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: HiveTableScan => true }.isEmpty) assert( - ctx.sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: PhysicalRDD => true }.nonEmpty) } test("scan an empty parquet table") { - checkAnswer(ctx.sql("SELECT count(*) FROM test_parquet"), Row(0)) + checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) } test("scan an empty parquet table with upper case") { - checkAnswer(ctx.sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) + checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) } test("insert into an empty parquet table") { dropTables("test_insert_parquet") - ctx.sql( + sql( """ |create table test_insert_parquet |( @@ -225,21 +225,21 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Insert into am empty table. - ctx.sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") + sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") checkAnswer( - ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), Row(6, "str6") :: Row(7, "str7") :: Nil ) // Insert overwrite. - ctx.sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") checkAnswer( - ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) dropTables("test_insert_parquet") // Create it again. 
- ctx.sql( + sql( """ |create table test_insert_parquet |( @@ -252,15 +252,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) // Insert overwrite an empty table. - ctx.sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") checkAnswer( - ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) // Insert into the table. - ctx.sql("insert into table test_insert_parquet select a, b from jt") + sql("insert into table test_insert_parquet select a, b from jt") checkAnswer( - ctx.sql(s"SELECT intField, stringField FROM test_insert_parquet"), + sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) dropTables("test_insert_parquet") @@ -268,7 +268,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("scan a parquet table created through a CTAS statement") { withTable("test_parquet_ctas") { - ctx.sql( + sql( """ |create table test_parquet_ctas ROW FORMAT |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' @@ -279,11 +279,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) checkAnswer( - ctx.sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), Seq(Row(1, "str1")) ) - ctx.table("test_parquet_ctas").queryExecution.optimizedPlan match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(_: ParquetRelation) => // OK case _ => fail( "test_parquet_ctas should be converted to " + @@ -294,7 +294,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("MetastoreRelation in InsertIntoTable will be converted") { withTable("test_insert_parquet") { - ctx.sql( + sql( """ |create table test_insert_parquet |( @@ -306,7 +306,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - val df = ctx.sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + @@ -316,15 +316,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } checkAnswer( - ctx.sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - ctx.sql("SELECT a FROM jt WHERE jt.a > 5").collect() + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() ) } } test("MetastoreRelation in InsertIntoHiveTable will be converted") { withTable("test_insert_parquet") { - ctx.sql( + sql( """ |create table test_insert_parquet |( @@ -336,7 +336,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - val df = ctx.sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + val df = sql("INSERT INTO TABLE test_insert_parquet 
SELECT a FROM jt_array") df.queryExecution.executedPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + @@ -346,15 +346,15 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } checkAnswer( - ctx.sql("SELECT int_array FROM test_insert_parquet"), - ctx.sql("SELECT a FROM jt_array").collect() + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() ) } } test("SPARK-6450 regression test") { withTable("ms_convert") { - ctx.sql( + sql( """CREATE TABLE IF NOT EXISTS ms_convert (key INT) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS @@ -363,7 +363,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // This shouldn't throw AnalysisException - val analyzed = ctx.sql( + val analyzed = sql( """SELECT key FROM ms_convert |UNION ALL |SELECT key FROM ms_convert @@ -388,7 +388,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { withTable("nonPartitioned") { - ctx.sql( + sql( s"""CREATE TABLE nonPartitioned ( | key INT, | value STRING @@ -397,9 +397,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(ctx.table("nonPartitioned")) + val r1 = collectParquetRelation(table("nonPartitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(ctx.table("nonPartitioned")) + val r2 = collectParquetRelation(table("nonPartitioned")) // They should be the same instance assert(r1 eq r2) } @@ -407,7 +407,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { withTable("partitioned") { - ctx.sql( + sql( s"""CREATE TABLE partitioned ( | key INT, | value STRING @@ -417,19 +417,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(ctx.table("partitioned")) + val r1 = collectParquetRelation(table("partitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(ctx.table("partitioned")) + val r2 = collectParquetRelation(table("partitioned")) // They should be the same instance assert(r1 eq r2) } } test("Caching converted data source Parquet Relations") { - val _ctx = ctx - def checkCached(tableIdentifier: _ctx.catalog.QualifiedTableName): Unit = { + def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { // Converted test_parquet should be cached. 
- ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { + catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK case other => @@ -441,7 +440,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") - ctx.sql( + sql( """ |create table test_insert_parquet |( @@ -454,18 +453,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifier = _ctx.catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") // First, make sure the converted test_parquet is not cached. - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Table lookup will make the table cached. - ctx.table("test_insert_parquet") + table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. - ctx.invalidateTable("test_insert_parquet") - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - ctx.sql( + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( """ |INSERT INTO TABLE test_insert_parquet |select a, b from jt @@ -473,14 +472,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - ctx.sql("select * from test_insert_parquet"), - ctx.sql("select a, b from jt").collect()) + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) // Invalidate the cache. - ctx.invalidateTable("test_insert_parquet") - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Create a partitioned table. - ctx.sql( + sql( """ |create table test_parquet_partitioned_cache_test |( @@ -494,10 +493,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifier = _ctx.catalog.QualifiedTableName( - "default", "test_parquet_partitioned_cache_test") - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - ctx.sql( + tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-01') @@ -505,30 +503,30 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. 
- assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - ctx.sql( + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Make sure we can cache the partitioned table. - ctx.table("test_parquet_partitioned_cache_test") + table("test_parquet_partitioned_cache_test") checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - ctx.sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), - ctx.sql( + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), + sql( """ |select b, '2015-04-01', a FROM jt |UNION ALL |select b, '2015-04-02', a FROM jt """.stripMargin).collect()) - ctx.invalidateTable("test_parquet_partitioned_cache_test") - assert(ctx.catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } @@ -538,8 +536,6 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { - import testImplicits._ - override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -548,7 +544,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { "partitioned_parquet_with_key_and_complextypes", "normal_parquet") - ctx.sql( s""" + sql( s""" create temporary table partitioned_parquet USING org.apache.spark.sql.parquet OPTIONS ( @@ -556,7 +552,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - ctx.sql( s""" + sql( s""" create temporary table partitioned_parquet_with_key USING org.apache.spark.sql.parquet OPTIONS ( @@ -564,7 +560,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - ctx.sql( s""" + sql( s""" create temporary table normal_parquet USING org.apache.spark.sql.parquet OPTIONS ( @@ -572,7 +568,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - ctx.sql( s""" + sql( s""" CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes USING org.apache.spark.sql.parquet OPTIONS ( @@ -580,7 +576,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) """) - ctx.sql( s""" + sql( s""" CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes USING org.apache.spark.sql.parquet OPTIONS ( @@ -590,29 +586,29 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } test("SPARK-6016 make sure to use the latest footers") { - ctx.sql("drop table if exists spark_6016_fix") + sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. 
- val df1 = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( - ctx.sql("select * from spark_6016_fix"), + sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = ctx.read.json(ctx.sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files // and then merge metadata in footers of these four (two outdated ones and two latest one), // which will cause an error. checkAnswer( - ctx.sql("select * from spark_6016_fix"), + sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) - ctx.sql("drop table spark_6016_fix") + sql("drop table spark_6016_fix") } test("SPARK-8811: compatibility with array of struct in Hive") { @@ -626,14 +622,14 @@ class ParquetSourceSuite extends ParquetPartitioningTest { SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") withSQLConf(conf: _*) { - ctx.sql( + sql( s"""CREATE TABLE array_of_struct |STORED AS PARQUET LOCATION '$path' |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) """.stripMargin) checkAnswer( - ctx.read.parquet(path), + sqlContext.read.parquet(path), Row("1st", "2nd", Seq(Row("val_a", "val_b")))) } } @@ -641,7 +637,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } test("values in arrays and maps stored in parquet are always nullable") { - val df = ctx.createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) val arrayType1 = ArrayType(IntegerType, containsNull = false) val expectedSchema1 = @@ -660,10 +656,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest { StructField("m", mapType2, nullable = true) :: StructField("a", arrayType2, nullable = true) :: Nil) - assert(ctx.table("alwaysNullable").schema === expectedSchema2) + assert(table("alwaysNullable").schema === expectedSchema2) checkAnswer( - ctx.sql("SELECT m, a FROM alwaysNullable"), + sql("SELECT m, a FROM alwaysNullable"), Row(Map(2 -> 3), Seq(4, 5, 6))) } } @@ -679,7 +675,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { val df3 = df2.toDF("str", "max_int") df3.write.parquet(filePath2) - val df4 = ctx.read.parquet(filePath2) + val df4 = read.parquet(filePath2) checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } @@ -688,8 +684,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. 
*/ -abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext { - import testImplicits._ +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override def sqlContext: SQLContext = TestHive var partitionedTableDir: File = null var normalTableDir: File = null @@ -698,19 +694,18 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext var partitionedTableDirWithKeyAndComplexTypes: File = null override def beforeAll(): Unit = { - super.beforeAll() partitionedTableDir = Utils.createTempDir() normalTableDir = Utils.createTempDir() (1 to 10).foreach { p => val partDir = new File(partitionedTableDir, s"p=$p") - ctx.sparkContext.makeRDD(1 to 10) + sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) .toDF() .write.parquet(partDir.getCanonicalPath) } - ctx.sparkContext + sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) .toDF() @@ -720,7 +715,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKey, s"p=$p") - ctx.sparkContext.makeRDD(1 to 10) + sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) .toDF() .write.parquet(partDir.getCanonicalPath) @@ -730,7 +725,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") - ctx.sparkContext.makeRDD(1 to 10).map { i => + sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithKeyAndComplexTypes( p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) }.toDF().write.parquet(partDir.getCanonicalPath) @@ -740,22 +735,18 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext (1 to 10).foreach { p => val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") - ctx.sparkContext.makeRDD(1 to 10).map { i => + sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) }.toDF().write.parquet(partDir.getCanonicalPath) } } override protected def afterAll(): Unit = { - try { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() - } finally { - super.afterAll() - } + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() } /** @@ -764,7 +755,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext */ def dropTables(tableNames: String*): Unit = { tableNames.foreach { name => - ctx.sql(s"DROP TABLE IF EXISTS $name") + sql(s"DROP TABLE IF EXISTS $name") } } @@ -776,19 +767,19 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext test(s"ordering of the partitioning columns $table") { checkAnswer( - ctx.sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), Seq.fill(10)(Row(1, "part-1")) ) checkAnswer( - ctx.sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), Seq.fill(10)(Row("part-1", 1)) ) } test(s"project the partitioning column $table") { checkAnswer( - ctx.sql(s"SELECT p, count(*) FROM $table group by p"), + sql(s"SELECT p, count(*) 
FROM $table group by p"), Row(1, 10) :: Row(2, 10) :: Row(3, 10) :: @@ -804,7 +795,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext test(s"project partitioning and non-partitioning columns $table") { checkAnswer( - ctx.sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), Row("part-1", 1, 10) :: Row("part-2", 2, 10) :: Row("part-3", 3, 10) :: @@ -820,44 +811,44 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext test(s"simple count $table") { checkAnswer( - ctx.sql(s"SELECT COUNT(*) FROM $table"), + sql(s"SELECT COUNT(*) FROM $table"), Row(100)) } test(s"pruned count $table") { checkAnswer( - ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), Row(10)) } test(s"non-existent partition $table") { checkAnswer( - ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), Row(0)) } test(s"multi-partition pruned count $table") { checkAnswer( - ctx.sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), Row(30)) } test(s"non-partition predicates $table") { checkAnswer( - ctx.sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), Row(30)) } test(s"sum $table") { checkAnswer( - ctx.sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), Row(1 + 2 + 3)) } test(s"hive udfs $table") { checkAnswer( - ctx.sql(s"SELECT concat(stringField, stringField) FROM $table"), - ctx.sql(s"SELECT stringField FROM $table").map { + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { case Row(s: String) => Row(s + s) }.collect().toSeq) } @@ -869,7 +860,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext test(s"SPARK-5775 read struct from $table") { checkAnswer( - ctx.sql( + sql( s""" |SELECT p, structField.intStructField, structField.stringStructField |FROM $table WHERE p = 1 @@ -880,7 +871,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext // Re-enable this after SPARK-5508 is fixed ignore(s"SPARK-5775 read array from $table") { checkAnswer( - ctx.sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), (1 to 10).map(i => Row(1 to i, 1))) } } @@ -888,7 +879,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SharedHiveContext test("non-part select(*)") { checkAnswer( - ctx.sql("SELECT COUNT(*) FROM normal_parquet"), + sql("SELECT COUNT(*) FROM normal_parquet"), Row(10)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index 6ab7bc4fde90..e976125b3706 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -20,9 +20,13 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} -import 
org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override val sqlContext = TestHive -class CommitFailureTestRelationSuite extends SparkFunSuite with SharedHiveContext { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName @@ -31,7 +35,7 @@ class CommitFailureTestRelationSuite extends SparkFunSuite with SharedHiveContex // Here we coalesce partition number to 1 to ensure that only a single task is issued. This // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = ctx.range(0, 10).coalesce(1) + val df = sqlContext.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 0a823bcaa178..ed6d512ab36f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + import sqlContext._ + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -34,7 +36,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - ctx.sparkContext + sparkContext .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") .saveAsTextFile(partitionDir.toString) } @@ -43,7 +45,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - ctx.read.format(dataSourceName) + read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -61,14 +63,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { val data = Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil - val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) + val df = createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. 
checkAnswer( - ctx.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 0d9c9b42a8ab..cb4cedddbfdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - import testImplicits._ - override val dataSourceName: String = "parquet" + import sqlContext._ + import sqlContext.implicits._ + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -40,7 +41,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - ctx.sparkContext + sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write.parquet(partitionDir.toString) @@ -50,7 +51,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - ctx.read.format(dataSourceName) + read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -68,7 +69,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(ctx.read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -96,7 +97,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(ctx.read.format("parquet").load(path), df) + checkAnswer(read.format("parquet").load(path), df) } } @@ -106,7 +107,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. 
- ctx.range(1, 10) + range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") @@ -118,7 +119,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withTempPath { dir => val path = dir.getCanonicalPath - val df = ctx.range(0, 5) + val df = sqlContext.range(0, 5) df.write.mode(SaveMode.Overwrite).parquet(path) val summaryPath = new Path(path, "_metadata") @@ -129,7 +130,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { fs.delete(commonSummaryPath, true) df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(ctx.read.parquet(path), df.unionAll(df)) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) assert(fs.exists(summaryPath)) assert(fs.exists(commonSummaryPath)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index f6e47814d8cc..e8975e5f5cd0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -23,9 +23,10 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + import sqlContext._ + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -34,7 +35,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - ctx.sparkContext + sparkContext .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") .saveAsTextFile(partitionDir.toString) } @@ -43,7 +44,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - ctx.read.format(dataSourceName) + read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 37abeade74e4..2a69d331b6e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,12 +28,16 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.SharedHiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { - import testImplicits._ +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { + override lazy val sqlContext: SQLContext = TestHive + + import sqlContext.sql + import sqlContext.implicits._ val dataSourceName: String @@ -88,7 +92,7 @@ abstract class 
HadoopFsRelationTest extends QueryTest with SharedHiveContext { df.registerTempTable("t") withTempTable("t") { checkAnswer( - ctx.sql( + sql( """SELECT l.a, r.b, l.p1, r.p2 |FROM t l JOIN t r |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 @@ -103,7 +107,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -117,7 +121,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -137,7 +141,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) assert(fs.listStatus(path).isEmpty) } } @@ -151,7 +155,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .save(file.getCanonicalPath) checkQueries( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -172,7 +176,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .save(file.getCanonicalPath) checkAnswer( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -194,7 +198,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .save(file.getCanonicalPath) checkAnswer( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -216,7 +220,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .save(file.getCanonicalPath) checkAnswer( - ctx.read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -252,7 +256,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), testDF.collect()) + checkAnswer(sqlContext.table("t"), testDF.collect()) } } @@ -261,7 +265,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -280,7 +284,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { withTempTable("t") { testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(ctx.table("t").collect().isEmpty) + 
assert(sqlContext.table("t").collect().isEmpty) } } @@ -291,7 +295,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkQueries(ctx.table("t")) + checkQueries(sqlContext.table("t")) } } @@ -311,7 +315,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -331,7 +335,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } @@ -351,7 +355,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -400,7 +404,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .partitionBy("p1", "p2") .saveAsTable("t") - assert(ctx.table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -412,7 +416,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .partitionBy("p1", "p2") .save(file.getCanonicalPath) - val df = ctx.read + val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) .load(s"${file.getCanonicalPath}/p1=*/p2=???") @@ -424,7 +428,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { s"${file.getCanonicalFile}/p1=2/p2=bar" ).map { p => val path = new Path(p) - val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } @@ -454,7 +458,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTempTable("t") { - checkAnswer(ctx.table("t"), input.collect()) + checkAnswer(sqlContext.table("t"), input.collect()) } } } @@ -469,7 +473,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .saveAsTable("t") withTable("t") { - checkAnswer(ctx.table("t"), df.select('b, 'c, 'a).collect()) + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } @@ -481,7 +485,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { test("SPARK-8406: Avoids name collision while writing files") { withTempPath { dir => val path = dir.getCanonicalPath - ctx + sqlContext .range(10000) .repartition(250) .write @@ -490,7 +494,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { .save(path) assertResult(10000) { - ctx + sqlContext .read .format(dataSourceName) .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) @@ -503,7 +507,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SharedHiveContext { test("SPARK-8578 specified custom output committer will not be used to append data") { val clonedConf = new Configuration(configuration) try { - val df = ctx.range(1, 10).toDF("i") + val df = sqlContext.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) configuration.set( @@ -518,7 +522,7 @@ abstract 
class HadoopFsRelationTest extends QueryTest with SharedHiveContext { // with file format and AlwaysFailOutputCommitter will not be used. df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) checkAnswer( - ctx.read + sqlContext.read .format(dataSourceName) .option("dataSchema", df.schema.json) .load(dir.getCanonicalPath), From d85a6d8bc8fbc7dfd15400a07931723b1f334912 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 13:10:22 -0700 Subject: [PATCH 38/39] Add a shorthand for sqlContext.sql / ctx.sql Per @marmbrus request. --- project/SparkBuild.scala | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 58 +-- .../spark/sql/ColumnExpressionSuite.scala | 12 +- .../spark/sql/DataFrameFunctionsSuite.scala | 10 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 2 +- .../apache/spark/sql/DateFunctionsSuite.scala | 6 +- .../org/apache/spark/sql/JoinSuite.scala | 32 +- .../apache/spark/sql/ListTablesSuite.scala | 8 +- .../spark/sql/MathExpressionsSuite.scala | 20 +- .../org/apache/spark/sql/SQLConfSuite.scala | 14 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 430 +++++++++--------- .../sql/ScalaReflectionRelationSuite.scala | 10 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 26 +- .../spark/sql/UserDefinedTypeSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 22 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 8 +- .../datasources/json/JsonSuite.scala | 127 +++--- .../ParquetPartitionDiscoverySuite.scala | 28 +- .../parquet/ParquetQuerySuite.scala | 21 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 77 ++-- .../spark/sql/jdbc/JDBCWriteSuite.scala | 10 +- .../sources/CreateTableAsSelectSuite.scala | 39 +- .../spark/sql/sources/DDLTestSuite.scala | 5 +- .../spark/sql/sources/FilteredScanSuite.scala | 5 +- .../spark/sql/sources/InsertSuite.scala | 89 ++-- .../spark/sql/sources/PrunedScanSuite.scala | 5 +- .../spark/sql/sources/SaveLoadSuite.scala | 3 +- .../spark/sql/sources/TableScanSuite.scala | 42 +- .../apache/spark/sql/test/SQLTestData.scala | 44 +- .../apache/spark/sql/test/SQLTestUtils.scala | 3 + .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 6 +- 33 files changed, 587 insertions(+), 587 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 109722210884..04e0d49b178c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -366,8 +366,8 @@ object Hive { |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.types._ - """.stripMargin, + |import org.apache.spark.sql.hive.test.TestHive.implicits._ + |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. 
This is only required for developers who are adding new diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index b412af9d5112..af7590c3d3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -68,9 +68,9 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("cache temp table") { testData.select('key).registerTempTable("tempTable") - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable"), 0) + assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) ctx.cacheTable("tempTable") - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable")) + assertCached(sql("SELECT COUNT(*) FROM tempTable")) ctx.uncacheTable("tempTable") } @@ -89,8 +89,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("cache table as select") { - ctx.sql("CACHE TABLE tempTable AS SELECT key FROM testData") - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable")) + sql("CACHE TABLE tempTable AS SELECT key FROM testData") + assertCached(sql("SELECT COUNT(*) FROM tempTable")) ctx.uncacheTable("tempTable") } @@ -99,14 +99,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { testData.select('key).registerTempTable("tempTable2") ctx.cacheTable("tempTable1") - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable1")) - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable2")) + assertCached(sql("SELECT COUNT(*) FROM tempTable1")) + assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? ctx.uncacheTable("tempTable2") // Should this be cached? - assertCached(ctx.sql("SELECT COUNT(*) FROM tempTable1"), 0) + assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) } test("too big for memory") { @@ -187,26 +187,26 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("SELECT star from cached table") { - ctx.sql("SELECT * FROM testData").registerTempTable("selectStar") + sql("SELECT * FROM testData").registerTempTable("selectStar") ctx.cacheTable("selectStar") checkAnswer( - ctx.sql("SELECT * FROM selectStar WHERE key = 1"), + sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = - ctx.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() ctx.cacheTable("testData") checkAnswer( - ctx.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") @@ -214,7 +214,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { @@ -223,7 +223,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - ctx.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + sql("CACHE TABLE testCacheTable AS SELECT 
* FROM testData") assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") @@ -238,7 +238,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE TABLE tableName AS SELECT ...") { - ctx.sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") + sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") @@ -253,7 +253,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("CACHE LAZY TABLE tableName") { - ctx.sql("CACHE LAZY TABLE testData") + sql("CACHE LAZY TABLE testData") assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") @@ -261,7 +261,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { !isMaterialized(rddId), "Lazily cached in-memory table shouldn't be materialized eagerly") - ctx.sql("SELECT COUNT(*) FROM testData").collect() + sql("SELECT COUNT(*) FROM testData").collect() assert( isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") @@ -273,7 +273,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("InMemoryRelation statistics") { - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum @@ -302,32 +302,32 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("Clear all cache") { - ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") ctx.cacheTable("t1") ctx.cacheTable("t2") ctx.clearCache() assert(ctx.cacheManager.isEmpty) - ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") ctx.cacheTable("t1") ctx.cacheTable("t2") - ctx.sql("Clear CACHE") + sql("Clear CACHE") assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { - ctx.sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - ctx.sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") ctx.cacheTable("t1") ctx.cacheTable("t2") - ctx.sql("SELECT * FROM t1").count() - ctx.sql("SELECT * FROM t2").count() - ctx.sql("SELECT * FROM t1").count() - ctx.sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() Accumulators.synchronized { val accsSize = Accumulators.originals.size diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 053c0f052d9c..ee74e3e83da5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -262,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { 
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -272,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } @@ -293,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } @@ -313,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -520,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -541,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7d6ef5a1f085..9d965258e389 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -117,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -151,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -220,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -231,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 9d7cb2de67b1..e2716d7841d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -57,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer( df1.join(df2, $"df1.key" === 
$"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 8cfa9189ef07..9080c53c491a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -33,7 +33,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } @@ -43,9 +43,9 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( 0).getTime - System.currentTimeMillis()) < 5000) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e52a6f96b921..f5c5046a8ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -36,7 +36,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -119,7 +119,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { @@ -134,12 +134,12 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", @@ -160,7 +160,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -272,7 +272,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // Make sure we are choosing 
left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -286,7 +286,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -332,7 +332,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -341,7 +341,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -393,7 +393,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -402,7 +402,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -417,7 +417,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -432,7 +432,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -443,7 +443,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( @@ -462,11 +462,11 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2a80cab0bc51..babf8835d254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -41,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -54,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in 
DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -66,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 455bf306d7b3..30289c3c1d09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -147,7 +147,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -155,7 +155,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -167,7 +167,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } @@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -231,7 +231,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -239,7 +239,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -278,7 +278,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -373,7 +373,7 @@ class MathExpressionsSuite extends 
QueryTest with SharedSQLContext { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -382,13 +382,13 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index fa2aabb4f2fc..7699adadd9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -51,21 +51,21 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("parse SQL set commands") { ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(ctx.getConf(testKey, testVal + "_") === testVal) assert(ctx.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") + sql("set some.property=20") assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") + sql("set some.property = 40") assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") + sql(s"set $key=$vs") assert(ctx.getConf(key, "0") === vs) - ctx.sql(s"set $key=") + sql(s"set $key=") assert(ctx.getConf(key, "0") === "") ctx.conf.clear() @@ -73,14 +73,14 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("deprecated property") { ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(ctx.conf.numShufflePartitions === 10) } test("invalid conf value") { ctx.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 312f6f008080..8c2c328f8191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -40,27 +40,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") checkAnswer( - ctx.sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), + sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), Row("one", 6) :: Row("three", 3) :: Nil) } test("SPARK-8010: promote numeric to string") { val df = Seq((1, 1)).toDF("key", "value") df.registerTempTable("src") - val queryCaseWhen = ctx.sql("select case when true then 1.0 else '1' end from src ") - val queryCoalesce = ctx.sql("select coalesce(null, 1, '1') from src ") + val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = 
sql("select coalesce(null, 1, '1') from src ") checkAnswer(queryCaseWhen, Row("1.0") :: Nil) checkAnswer(queryCoalesce, Row("1") :: Nil) } test("show functions") { - checkAnswer(ctx.sql("SHOW functions"), + checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { - checkExistence(ctx.sql("describe function extended upper"), true, + checkExistence(sql("describe function extended upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase", @@ -68,15 +68,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "> SELECT upper('SparkSql');", "'SPARKSQL'") - checkExistence(ctx.sql("describe functioN Upper"), true, + checkExistence(sql("describe functioN Upper"), true, "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", "Usage: upper(str) - Returns str with all characters changed to uppercase") - checkExistence(ctx.sql("describe functioN Upper"), false, + checkExistence(sql("describe functioN Upper"), false, "Extended Usage") - checkExistence(ctx.sql("describe functioN abcadf"), true, + checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf is not found.") } @@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.cacheTable("cachedData") checkAnswer( - ctx.sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) } @@ -97,7 +97,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( - ctx.sql( + sql( """ |SELECT x.str, COUNT(*) |FROM df x JOIN df y ON x.str = y.str @@ -108,7 +108,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("support table.star") { checkAnswer( - ctx.sql( + sql( """ |SELECT r.* |FROM testData l join testData2 r on (l.key = r.a) @@ -125,7 +125,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .registerTempTable("df") checkAnswer( - ctx.sql( + sql( """ |SELECT x.str, SUM(x.strCount) |FROM df x JOIN df y ON x.str = y.str @@ -164,7 +164,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( - ctx.sql("SELECT a FROM testData2 SORT BY a"), + sql("SELECT a FROM testData2 SORT BY a"), Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } @@ -200,7 +200,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .registerTempTable("rows") checkAnswer( - ctx.sql( + sql( """ |select attribute, sum(cnt) |from ( @@ -219,7 +219,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .registerTempTable("d") checkAnswer( - ctx.sql("select * from d where d.a in (1,2)"), + sql("select * from d where d.a in (1,2)"), Seq(Row("1"), Row("2"))) } @@ -227,13 +227,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( - ctx.sql("select sum(a), avg(a) from allNulls"), + sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { checkAnswer( - ctx.sql("select sum(a), avg(a) from allNulls"), + sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } @@ -241,13 +241,13 @@ class SQLQuerySuite extends QueryTest 
with SharedSQLContext { withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( - ctx.sql("select sum(a), avg(a) from allNulls"), + sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { checkAnswer( - ctx.sql("select sum(a), avg(a) from allNulls"), + sql("select sum(a), avg(a) from allNulls"), Seq(Row(null, null)) ) } @@ -255,7 +255,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = ctx.sql(sqlText) + val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan .collect { case _: aggregate.TungstenAggregate => true } @@ -357,82 +357,82 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Add Parser of SQL COALESCE()") { checkAnswer( - ctx.sql("""SELECT COALESCE(1, 2)"""), + sql("""SELECT COALESCE(1, 2)"""), Row(1)) checkAnswer( - ctx.sql("SELECT COALESCE(null, 1, 1.5)"), + sql("SELECT COALESCE(null, 1, 1.5)"), Row(BigDecimal(1))) checkAnswer( - ctx.sql("SELECT COALESCE(null, null, null)"), + sql("SELECT COALESCE(null, null, null)"), Row(null)) } test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( - ctx.sql("SELECT LAST(n) FROM lowerCaseData"), + sql("SELECT LAST(n) FROM lowerCaseData"), Row(4)) } test("SPARK-2041 column name equals tablename") { checkAnswer( - ctx.sql("SELECT tableName FROM tableName"), + sql("SELECT tableName FROM tableName"), Row("test")) } test("SQRT") { checkAnswer( - ctx.sql("SELECT SQRT(key) FROM testData"), + sql("SELECT SQRT(key) FROM testData"), (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } test("SQRT with automatic string casts") { checkAnswer( - ctx.sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), + sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq ) } test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( - ctx.sql("SELECT substr(tableName, 1, 2) FROM tableName"), + sql("SELECT substr(tableName, 1, 2) FROM tableName"), Row("te")) checkAnswer( - ctx.sql("SELECT substr(tableName, 3) FROM tableName"), + sql("SELECT substr(tableName, 3) FROM tableName"), Row("st")) checkAnswer( - ctx.sql("SELECT substring(tableName, 1, 2) FROM tableName"), + sql("SELECT substring(tableName, 1, 2) FROM tableName"), Row("te")) checkAnswer( - ctx.sql("SELECT substring(tableName, 3) FROM tableName"), + sql("SELECT substring(tableName, 3) FROM tableName"), Row("st")) } test("SPARK-3173 Timestamp support in the parser") { (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) - checkAnswer(ctx.sql( + checkAnswer(sql( """SELECT time FROM timestamps WHERE 
time<'1969-12-31 16:00:00.003' AND time>'1969-12-31 16:00:00.001'"""), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) - checkAnswer(ctx.sql( + checkAnswer(sql( """ |SELECT time FROM timestamps |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') @@ -440,41 +440,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), Nil) } test("index into array") { checkAnswer( - ctx.sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), + sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) } test("left semi greater than predicate") { checkAnswer( - ctx.sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), Seq(Row(3, 1), Row(3, 2)) ) } test("left semi greater than predicate and equal operator") { checkAnswer( - ctx.sql( - "SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), Seq(Row(3, 1), Row(3, 2)) ) checkAnswer( - ctx.sql( - "SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) ) } test("index into array of arrays") { checkAnswer( - ctx.sql( + sql( "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), arrayData.map(d => Row(d.nestedData, @@ -484,28 +482,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("agg") { checkAnswer( - ctx.sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), + sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { def literalInAggTest(): Unit = { checkAnswer( - ctx.sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( - ctx.sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( - ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) checkAnswer( - ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - ctx.sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) checkAnswer( - ctx.sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - ctx.sql("SELECT 1, 2, sum(b) FROM testData2")) + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) } literalInAggTest() @@ -516,62 +514,62 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - ctx.sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), + sql("SELECT MIN(a), 
MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), Row(1, 3, 2, 6, 3) ) } test("select *") { checkAnswer( - ctx.sql("SELECT * FROM testData"), + sql("SELECT * FROM testData"), testData.collect().toSeq) } test("simple select") { checkAnswer( - ctx.sql("SELECT value FROM testData WHERE key = 1"), + sql("SELECT value FROM testData WHERE key = 1"), Row("1")) } def sortTest(): Unit = { checkAnswer( - ctx.sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( - ctx.sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), + sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( - ctx.sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), + sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( - ctx.sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( - ctx.sql("SELECT b FROM binaryData ORDER BY a ASC"), + sql("SELECT b FROM binaryData ORDER BY a ASC"), (1 to 5).map(Row(_))) checkAnswer( - ctx.sql("SELECT b FROM binaryData ORDER BY a DESC"), + sql("SELECT b FROM binaryData ORDER BY a DESC"), (1 to 5).map(Row(_)).toSeq.reverse) checkAnswer( - ctx.sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), + sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq) checkAnswer( - ctx.sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), + sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq) checkAnswer( - ctx.sql("SELECT * FROM mapData ORDER BY data[1] ASC"), + sql("SELECT * FROM mapData ORDER BY data[1] ASC"), mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq) checkAnswer( - ctx.sql("SELECT * FROM mapData ORDER BY data[1] DESC"), + sql("SELECT * FROM mapData ORDER BY data[1] DESC"), mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } @@ -603,25 +601,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("limit") { checkAnswer( - ctx.sql("SELECT * FROM testData LIMIT 10"), + sql("SELECT * FROM testData LIMIT 10"), testData.take(10).toSeq) checkAnswer( - ctx.sql("SELECT * FROM arrayData LIMIT 1"), + sql("SELECT * FROM arrayData LIMIT 1"), arrayData.collect().take(1).map(Row.fromTuple).toSeq) checkAnswer( - ctx.sql("SELECT * FROM mapData LIMIT 1"), + sql("SELECT * FROM mapData LIMIT 1"), mapData.collect().take(1).map(Row.fromTuple).toSeq) } test("CTE feature") { checkAnswer( - ctx.sql("with q1 as (select * from testData limit 10) select * from q1"), + sql("with q1 as (select * from testData limit 10) select * from q1"), testData.take(10).toSeq) checkAnswer( - ctx.sql(""" + sql(""" |with q1 as (select * from testData where key= '5'), |q2 as (select * from testData where key = '4') |select * from q1 union all select * from q2""".stripMargin), @@ -631,20 +629,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - ctx.sql( + sql( "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } test("date row") { - checkAnswer(ctx.sql( + checkAnswer(sql( """select 
cast("2015-01-28" as date) from testData limit 1"""), Row(java.sql.Date.valueOf("2015-01-28")) ) } test("from follow multiple brackets") { - checkAnswer(ctx.sql( + checkAnswer(sql( """ |select key from ((select * from testData limit 1) | union all (select * from testData limit 1)) x limit 1 @@ -652,12 +650,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1) ) - checkAnswer(ctx.sql( + checkAnswer(sql( "select key from (select * from testData) x limit 1"), Row(1) ) - checkAnswer(ctx.sql( + checkAnswer(sql( """ |select key from | (select * from testData limit 1 union all select * from testData limit 1) x @@ -669,47 +667,47 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("average") { checkAnswer( - ctx.sql("SELECT AVG(a) FROM testData2"), + sql("SELECT AVG(a) FROM testData2"), Row(2.0)) } test("average overflow") { checkAnswer( - ctx.sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), + sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { checkAnswer( - ctx.sql("SELECT COUNT(*) FROM testData2"), + sql("SELECT COUNT(*) FROM testData2"), Row(testData2.count())) } test("count distinct") { checkAnswer( - ctx.sql("SELECT COUNT(DISTINCT b) FROM testData2"), + sql("SELECT COUNT(DISTINCT b) FROM testData2"), Row(2)) } test("approximate count distinct") { checkAnswer( - ctx.sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( - ctx.sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), Row(3)) } test("null count") { checkAnswer( - ctx.sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), + sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), Seq(Row(1, 0), Row(2, 1))) checkAnswer( - ctx.sql( + sql( "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -718,14 +716,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withTempTable("t") { Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") checkAnswer( - ctx.sql("select count(a) from t"), + sql("select count(a) from t"), Row(0)) } } test("inner join where, one match per row") { checkAnswer( - ctx.sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), + sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -735,7 +733,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("inner join ON, one match per row") { checkAnswer( - ctx.sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), + sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -745,7 +743,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("inner join, where, multiple matches") { checkAnswer( - ctx.sql(""" + sql(""" |SELECT * FROM | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y @@ -758,7 +756,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("inner join, no matches") { checkAnswer( - ctx.sql( + sql( """ |SELECT * FROM | (SELECT * FROM testData2 WHERE a = 1) x JOIN @@ -769,7 +767,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("big inner join, 4 matches per row") { checkAnswer( - ctx.sql( + sql( 
""" |SELECT * FROM | (SELECT * FROM testData UNION ALL @@ -796,7 +794,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("left outer join") { checkAnswer( - ctx.sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), + sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -807,7 +805,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("right outer join") { checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), + sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -818,7 +816,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("full outer join") { checkAnswer( - ctx.sql( + sql( """ |SELECT * FROM | (SELECT * FROM upperCaseData WHERE N <= 4) leftTable FULL OUTER JOIN @@ -834,25 +832,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3349 partitioning after limit") { - ctx.sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - ctx.sql("SELECT DISTINCT n FROM lowerCaseData") + sql("SELECT DISTINCT n FROM lowerCaseData") .limit(2) .registerTempTable("subset2") checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), + sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), Row(3, "c", 3) :: Row(4, "d", 4) :: Nil) checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), + sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), Row(1, "a", 1) :: Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { checkAnswer( - ctx.sql( + sql( """ |SeleCT * from | (select * from upperCaseData WherE N <= 4) leftTable fuLL OUtER joiN @@ -869,14 +867,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("select with table name as qualifier") { checkAnswer( - ctx.sql("SELECT testData.value FROM testData WHERE testData.key = 1"), + sql("SELECT testData.value FROM testData WHERE testData.key = 1"), Row("1")) } test("inner join ON with table name as qualifier") { checkAnswer( - ctx.sql( - "SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), + sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -886,7 +883,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("qualified select with inner join ON with table name as qualifier") { checkAnswer( - ctx.sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + + sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + "ON lowerCaseData.n = upperCaseData.N"), Seq( Row(1, "A"), @@ -897,7 +894,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("system function upper()") { checkAnswer( - ctx.sql("SELECT n,UPPER(l) FROM lowerCaseData"), + sql("SELECT n,UPPER(l) FROM lowerCaseData"), Seq( Row(1, "A"), Row(2, "B"), @@ -905,7 +902,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(4, "D"))) checkAnswer( - ctx.sql("SELECT n, UPPER(s) FROM nullStrings"), + sql("SELECT n, UPPER(s) FROM nullStrings"), Seq( Row(1, "ABC"), Row(2, "ABC"), 
@@ -914,7 +911,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("system function lower()") { checkAnswer( - ctx.sql("SELECT N,LOWER(L) FROM upperCaseData"), + sql("SELECT N,LOWER(L) FROM upperCaseData"), Seq( Row(1, "a"), Row(2, "b"), @@ -924,7 +921,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(6, "f"))) checkAnswer( - ctx.sql("SELECT n, LOWER(s) FROM nullStrings"), + sql("SELECT n, LOWER(s) FROM nullStrings"), Seq( Row(1, "abc"), Row(2, "abc"), @@ -933,14 +930,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("UNION") { checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), + sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: Row(4, "d") :: Row(4, "d") :: Nil) } @@ -948,63 +945,63 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("UNION with column mismatches") { // Column name mismatches are allowed. checkAnswer( - ctx.sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), + sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( - ctx.sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), + sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. 
intercept[AnalysisException] { - ctx.sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() + sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() } } test("EXCEPT") { checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( - ctx.sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) + sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } test("INTERSECT") { checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), + sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( - ctx.sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) + sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } - test("SET commands semantics using ctx.sql()") { + test("SET commands semantics using sql()") { sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. - assert(ctx.sql("SET").collect().size == 0) + assert(sql("SET").collect().size == 0) // "set key=val" - ctx.sql(s"SET $testKey=$testVal") + sql(s"SET $testKey=$testVal") checkAnswer( - ctx.sql("SET"), + sql("SET"), Row(testKey, testVal) ) - ctx.sql(s"SET ${testKey + testKey}=${testVal + testVal}") + sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( - ctx.sql("set"), + sql("set"), Seq( Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) @@ -1012,11 +1009,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // "set key" checkAnswer( - ctx.sql(s"SET $testKey"), + sql(s"SET $testKey"), Row(testKey, testVal) ) checkAnswer( - ctx.sql(s"SET $nonexistentKey"), + sql(s"SET $nonexistentKey"), Row(nonexistentKey, "") ) sqlContext.conf.clear() @@ -1026,9 +1023,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported - intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-1")) - intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-01")) - intercept[IllegalArgumentException](ctx.sql(s"SET mapred.reduce.tasks=-2")) + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) sqlContext.conf.clear() } @@ -1050,14 +1047,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( - ctx.sql("SELECT * FROM applySchema1"), + sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: Row(2, "B2", false, null) :: Row(3, "C3", true, null) :: Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( - ctx.sql("SELECT f1, f4 FROM applySchema1"), + sql("SELECT f1, f4 FROM applySchema1"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1080,14 
+1077,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( - ctx.sql("SELECT * FROM applySchema2"), + sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: Row(Row(2, false), Map("B2" -> null)) :: Row(Row(3, true), Map("C3" -> null)) :: Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( - ctx.sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1106,7 +1103,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df3.registerTempTable("applySchema3") checkAnswer( - ctx.sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), Row(1, null) :: Row(2, null) :: Row(3, null) :: @@ -1115,17 +1112,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3423 BETWEEN") { checkAnswer( - ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), + sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), Seq(Row(5, "5"), Row(6, "6"), Row(7, "7")) ) checkAnswer( - ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), + sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), Row(7, "7") ) checkAnswer( - ctx.sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), + sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), Nil ) } @@ -1133,12 +1130,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("cast boolean to string") { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( - ctx.sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), + sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), Row("true", "false")) } test("metadata is propagated correctly") { - val person: DataFrame = ctx.sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -1155,41 +1152,40 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"id", $"name")) - validateMetadata(ctx.sql("SELECT * FROM personWithMeta")) - validateMetadata(ctx.sql("SELECT id, name FROM personWithMeta")) - validateMetadata(ctx.sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(ctx.sql( + validateMetadata(sql("SELECT * FROM personWithMeta")) + validateMetadata(sql("SELECT id, name FROM personWithMeta")) + validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( - ctx.sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) } test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { checkAnswer( - ctx.sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by 
key"), Row(1)) } test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - ctx.sql( - "SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), Row(1)) } test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { if (isInvalidQuery) { - val e = intercept[AnalysisException](ctx.sql(query).queryExecution.analyzed) + val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) assert(e.getMessage contains "group by") } else { // Should not throw - ctx.sql(query).queryExecution.analyzed + sql(query).queryExecution.analyzed } } @@ -1205,137 +1201,137 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Test to check we can use Long.MinValue") { checkAnswer( - ctx.sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) + sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) ) checkAnswer( - ctx.sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), + sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq ) } test("Floating point number format") { checkAnswer( - ctx.sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - ctx.sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - ctx.sql("SELECT .5"), Row(BigDecimal(0.5)) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - ctx.sql("SELECT -.18"), Row(BigDecimal(-0.18)) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } test("Auto cast integer type") { checkAnswer( - ctx.sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) + sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) ) checkAnswer( - ctx.sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) + sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) ) checkAnswer( - ctx.sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) + sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - ctx.sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) + sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } test("Test to check we can apply sign to expression") { checkAnswer( - ctx.sql("SELECT -100"), Row(-100) + sql("SELECT -100"), Row(-100) ) checkAnswer( - ctx.sql("SELECT +230"), Row(230) + sql("SELECT +230"), Row(230) ) checkAnswer( - ctx.sql("SELECT -5.2"), Row(BigDecimal(-5.2)) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - ctx.sql("SELECT +6.8"), Row(BigDecimal(6.8)) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( - ctx.sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) + sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) ) checkAnswer( - ctx.sql("SELECT +key FROM testData WHERE key = 3"), Row(3) + sql("SELECT +key FROM testData WHERE key = 3"), Row(3) ) checkAnswer( - ctx.sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) + sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) ) checkAnswer( - ctx.sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) + sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) ) checkAnswer( - ctx.sql("SELECT +(key + 
5) FROM testData WHERE key = 5"), Row(10) + sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) ) checkAnswer( - ctx.sql("SELECT -MAX(key) FROM testData"), Row(-100) + sql("SELECT -MAX(key) FROM testData"), Row(-100) ) checkAnswer( - ctx.sql("SELECT +MAX(key) FROM testData"), Row(100) + sql("SELECT +MAX(key) FROM testData"), Row(100) ) checkAnswer( - ctx.sql("SELECT - (-10)"), Row(10) + sql("SELECT - (-10)"), Row(10) ) checkAnswer( - ctx.sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) + sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) ) checkAnswer( - ctx.sql("SELECT - (+Max(key)) FROM testData"), Row(-100) + sql("SELECT - (+Max(key)) FROM testData"), Row(-100) ) checkAnswer( - ctx.sql("SELECT - - 3"), Row(3) + sql("SELECT - - 3"), Row(3) ) checkAnswer( - ctx.sql("SELECT - + 20"), Row(-20) + sql("SELECT - + 20"), Row(-20) ) checkAnswer( - ctx.sql("SELEcT - + 45"), Row(-45) + sql("SELEcT - + 45"), Row(-45) ) checkAnswer( - ctx.sql("SELECT + + 100"), Row(100) + sql("SELECT + + 100"), Row(100) ) checkAnswer( - ctx.sql("SELECT - - Max(key) FROM testData"), Row(100) + sql("SELECT - - Max(key) FROM testData"), Row(100) ) checkAnswer( - ctx.sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) + sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) ) } test("Multiple join") { checkAnswer( - ctx.sql( + sql( """SELECT a.key, b.key, c.key |FROM testData a |JOIN testData b ON a.key = b.key @@ -1348,28 +1344,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") - ctx.sql("SELECT `key?number1`, `key.number2` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { - checkAnswer(ctx.sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise | operator") { - checkAnswer(ctx.sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ^ operator") { - checkAnswer(ctx.sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) + checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ~ operator") { - checkAnswer(ctx.sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) + checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) } test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { checkAnswer( - ctx.sql( + sql( """SELECT a.key, b.key, c.key |FROM testData a,testData b,testData c |where a.key = b.key and a.key = c.key @@ -1378,37 +1374,37 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(ctx.sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), + checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { - checkAnswer(ctx.sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), (1 to 99).map(i => Row(i))) } test("SPARK-4322 Grouping field 
with struct field as sub expression") { sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") - checkAnswer(ctx.sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") - checkAnswer(ctx.sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( - ctx.sql("SELECT a + b FROM testData2 ORDER BY a"), + sql("SELECT a + b FROM testData2 ORDER BY a"), Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( - ctx.sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), + sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } @@ -1420,7 +1416,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") - checkAnswer(ctx.sql("SELECT nulldata1.key FROM nulldata1 join " + + checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) } @@ -1429,7 +1425,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") - checkAnswer(ctx.sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) + checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { @@ -1437,7 +1433,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") - checkAnswer(ctx.sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) + checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } @@ -1446,19 +1442,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") - checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) - checkAnswer(ctx.sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) - checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) - checkAnswer(ctx.sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) - checkAnswer(ctx.sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) - checkAnswer(ctx.sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + 
checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) } test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(ctx.sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) - checkAnswer(ctx.sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { @@ -1466,7 +1462,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") - checkAnswer(ctx.sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } test("SPARK-6583 order by aggregated function") { @@ -1474,7 +1470,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .toDF("a", "b").registerTempTable("orderByData") checkAnswer( - ctx.sql( + sql( """ |SELECT a |FROM orderByData @@ -1484,7 +1480,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT sum(b) |FROM orderByData @@ -1494,7 +1490,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT a, sum(b) |FROM orderByData @@ -1504,7 +1500,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT a, sum(b) |FROM orderByData @@ -1529,8 +1525,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (null, null, null, true) ).toDF("i", "b", "r1", "r2").registerTempTable("t") - checkAnswer(ctx.sql("select i = b from t"), ctx.sql("select r1 from t")) - checkAnswer(ctx.sql("select i <=> b from t"), ctx.sql("select r2 from t")) + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) } } @@ -1538,14 +1534,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withTempTable("t") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") - checkAnswer(ctx.sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) + checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } test("SPARK-8782: ORDER BY NULL") { withTempTable("t") { Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") - checkAnswer(ctx.sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) } } @@ -1554,14 +1550,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val df = Seq(1 -> "a").toDF("count", "sort") checkAnswer(df.filter("count > 0"), Row(1, "a")) df.registerTempTable("t") - checkAnswer(ctx.sql("select count, sort from t"), Row(1, "a")) + checkAnswer(sql("select count, sort from t"), Row(1, "a")) } } test("SPARK-8753: add interval type") { import org.apache.spark.unsafe.types.CalendarInterval - val df = ctx.sql("select interval 3 years -3 month 7 week 
123 microseconds") + val df = sql("select interval 3 years -3 month 7 week 123 microseconds") checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. @@ -1573,7 +1569,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { def checkIntervalParseError(s: String): Unit = { val e = intercept[AnalysisException] { - ctx.sql(s) + sql(s) } e.message.contains("at least one time unit should be given for interval literal") } @@ -1587,7 +1583,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK - val df = ctx.sql("select interval 3 years -3 month 7 week 123 microseconds as i") + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), @@ -1626,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .toDF("num", "str") df.registerTempTable("1one") - checkAnswer(ctx.sql("select count(num) from 1one"), Row(10)) + checkAnswer(sql("select count(num) from 1one"), Row(10)) sqlContext.dropTempTable("1one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index c1ae8d04fab1..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -80,7 +80,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -90,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -98,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -106,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -125,7 +125,7 @@ class ScalaReflectionRelationSuite extends 
SparkFunSuite with SharedSQLContext { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 46056c16533f..eb275af101e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -54,7 +54,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") - checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) ctx.dropTempTable("tmp_table") } @@ -63,9 +63,9 @@ class UDFSuite extends QueryTest with SharedSQLContext { val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") - val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) - assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) ctx.dropTempTable("test_table") } } @@ -88,17 +88,17 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { @@ -109,7 +109,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { df.registerTempTable("integerData") val result = - ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } @@ -121,7 +121,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT g, SUM(v) as s | FROM groupData @@ -140,7 +140,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT SUM(v) | FROM groupData @@ -160,7 +160,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData @@ -175,7 +175,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + 
sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } @@ -183,12 +183,12 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("udf that is transformed") { ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { ctx.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. - assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index de637628debc..b6d279ae4726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -93,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 261a1878ac7f..952637c5f9cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -66,25 +66,25 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1678 regression: compression must not lose repeated values") { checkAnswer( - ctx.sql("SELECT * FROM repeatedData"), + sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("repeatedData") checkAnswer( - ctx.sql("SELECT * FROM repeatedData"), + sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) } test("with null values") { checkAnswer( - ctx.sql("SELECT * FROM nullableRepeatedData"), + sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("nullableRepeatedData") checkAnswer( - ctx.sql("SELECT * FROM nullableRepeatedData"), + sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) } @@ -93,25 +93,25 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { timestamps.registerTempTable("timestamps") checkAnswer( - ctx.sql("SELECT time FROM timestamps"), + sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) ctx.cacheTable("timestamps") checkAnswer( - ctx.sql("SELECT time FROM timestamps"), + sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { checkAnswer( - ctx.sql("SELECT * FROM withEmptyParts"), + sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) ctx.cacheTable("withEmptyParts") checkAnswer( - ctx.sql("SELECT * FROM withEmptyParts"), + sql("SELECT * FROM withEmptyParts"), 
withEmptyParts.collect().toSeq.map(Row.fromTuple)) } @@ -133,7 +133,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { df.cache().registerTempTable("test_fixed_decimal") checkAnswer( - ctx.sql("SELECT * FROM test_fixed_decimal"), + sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } @@ -179,7 +179,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. - ctx.sql("cache table InMemoryCache_different_data_types") + sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( @@ -187,7 +187,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( - ctx.sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), ctx.table("InMemoryCache_different_data_types").collect()) ctx.dropTempTable("InMemoryCache_different_data_types") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index fb9ff2f50325..ab2644eb4581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -107,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 1f36c28bd28b..937a10854353 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -91,7 +91,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { val rowRDD = ctx.sparkContext.parallelize(row :: Nil) ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") - val planned = ctx.sql( + val planned = sql( """ |SELECT l.a, l.b |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) @@ -146,7 +146,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") - ctx.sql("CACHE TABLE tiny") + sql("CACHE TABLE tiny") val a = testData.as("a") val b = ctx.table("tiny").as("b") @@ -176,7 +176,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { // Disable broadcast join withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { { - val numExchanges = ctx.sql( + val numExchanges = sql( """ |SELECT * |FROM @@ -191,7 +191,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { { // This second query joins on different keys: - val numExchanges = ctx.sql( + val numExchanges = sql( """ |SELECT * |FROM diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 636f55763a21..1174b27732f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -228,7 +228,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select nullstr, headers.Host from jsonTable"), + sql("select nullstr, headers.Host from jsonTable"), Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null)) ) } @@ -250,7 +250,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select * from jsonTable"), + sql("select * from jsonTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -292,45 +292,44 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Access elements of a primitive array. checkAnswer( - ctx.sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( - ctx.sql("select arrayOfNull from jsonTable"), + sql("select arrayOfNull from jsonTable"), Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - ctx.sql( - "select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( - ctx.sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( - ctx.sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( - ctx.sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), Row("str2", 2.1) ) // Access elements of an array of structs. checkAnswer( - ctx.sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + "from jsonTable"), Row( Row(true, "str1", null), @@ -341,7 +340,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Access a struct and fields inside of it. checkAnswer( - ctx.sql("select struct, struct.field1, struct.field2 from jsonTable"), + sql("select struct, struct.field1, struct.field2 from jsonTable"), Row( Row(true, new java.math.BigDecimal("92233720368547758070")), true, @@ -350,14 +349,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Access an array field of a struct. 
checkAnswer( - ctx.sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( - ctx.sql( - "select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), Row(5, null) ) } @@ -367,13 +365,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) // Getting all values of a specific field from an array of structs. checkAnswer( - ctx.sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), Row(Seq(true, false, null), Seq("str1", null, null)) ) } @@ -394,7 +392,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select * from jsonTable"), + sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, 21474836470.9, null, null, "true") :: Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: @@ -403,49 +401,49 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and Boolean conflict: resolve the type as number in this query. checkAnswer( - ctx.sql("select num_bool - 10 from jsonTable where num_bool > 11"), + sql("select num_bool - 10 from jsonTable where num_bool > 11"), Row(2) ) // Widening to LongType checkAnswer( - ctx.sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), Row(21474836370L) :: Row(21474836470L) :: Nil ) checkAnswer( - ctx.sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) // Widening to DecimalType checkAnswer( - ctx.sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), Row(21474836472.2) :: Row(92233720368547758071.3) :: Nil ) // Widening to Double checkAnswer( - ctx.sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), + sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - ctx.sql("select num_str + 1.2 from jsonTable where num_str > 14"), + sql("select num_str + 1.2 from jsonTable where num_str > 14"), Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - ctx.sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. 
checkAnswer( - ctx.sql("select * from jsonTable where str_bool = 'str1'"), + sql("select * from jsonTable where str_bool = 'str1'"), Row("true", 11L, null, 1.1, "13.1", "str1") ) } @@ -457,24 +455,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Right now, the analyzer does not promote strings in a boolean expression. // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( - ctx.sql("select num_bool from jsonTable where NOT num_bool"), + sql("select num_bool from jsonTable where NOT num_bool"), Row(false) ) checkAnswer( - ctx.sql("select str_bool from jsonTable where NOT str_bool"), + sql("select str_bool from jsonTable where NOT str_bool"), Row(false) ) // Right now, the analyzer does not know that num_bool should be treated as a boolean. // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( - ctx.sql("select num_bool from jsonTable where num_bool"), + sql("select num_bool from jsonTable where num_bool"), Row(true) ) checkAnswer( - ctx.sql("select str_bool from jsonTable where str_bool"), + sql("select str_bool from jsonTable where str_bool"), Row(false) ) @@ -498,7 +496,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // which is not 14.3. // Number and String conflict: resolve the type as number in this query. checkAnswer( - ctx.sql("select num_str + 1.2 from jsonTable where num_str > 13"), + sql("select num_str + 1.2 from jsonTable where num_str > 13"), Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } @@ -519,7 +517,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select * from jsonTable"), + sql("select * from jsonTable"), Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: @@ -541,7 +539,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select * from jsonTable"), + sql("select * from jsonTable"), Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: @@ -550,7 +548,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Treat an element as a number. 
checkAnswer( - ctx.sql("select array1[0] + 1 from jsonTable where array1 is not null"), + sql("select array1[0] + 1 from jsonTable where array1 is not null"), Row(2) ) } @@ -621,7 +619,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select * from jsonTable"), + sql("select * from jsonTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -638,7 +636,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTableSQL |USING org.apache.spark.sql.json @@ -648,7 +646,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { """.stripMargin) checkAnswer( - ctx.sql("select * from jsonTableSQL"), + sql("select * from jsonTableSQL"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -681,7 +679,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF1.registerTempTable("jsonTable1") checkAnswer( - ctx.sql("select * from jsonTable1"), + sql("select * from jsonTable1"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -698,7 +696,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF2.registerTempTable("jsonTable2") checkAnswer( - ctx.sql("select * from jsonTable2"), + sql("select * from jsonTable2"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -717,7 +715,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") checkAnswer( - ctx.sql("select map from jsonWithSimpleMap"), + sql("select map from jsonWithSimpleMap"), Row(Map("a" -> 1)) :: Row(Map("b" -> 2)) :: Row(Map("c" -> 3)) :: @@ -726,7 +724,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - ctx.sql("select map['c'] from jsonWithSimpleMap"), + sql("select map['c'] from jsonWithSimpleMap"), Row(null) :: Row(null) :: Row(3) :: @@ -745,7 +743,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithComplexMap.registerTempTable("jsonWithComplexMap") checkAnswer( - ctx.sql("select map from jsonWithComplexMap"), + sql("select map from jsonWithComplexMap"), Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: Row(Map("b" -> Row(null, 2))) :: Row(Map("c" -> Row(Seq(), 4))) :: @@ -755,7 +753,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - ctx.sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), Row(Seq(1, 2, 3, null), null) :: Row(null, null) :: Row(null, 4) :: @@ -770,11 +768,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) checkAnswer( - ctx.sql( + sql( """ |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] |from jsonTable @@ -788,7 +786,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { 
jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql( + sql( """ |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] |from jsonTable @@ -796,7 +794,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(5, 7, 8) ) checkAnswer( - ctx.sql( + sql( """ |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 @@ -811,7 +809,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") checkAnswer( - ctx.sql( + sql( """ |select a, b, c |from jsonTable @@ -841,7 +839,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // In HiveContext, backticks should be used to access columns starting with a underscore. checkAnswer( - ctx.sql( + sql( """ |SELECT a, b, c, _unparsed |FROM jsonTable @@ -855,7 +853,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - ctx.sql( + sql( """ |SELECT a, b, c |FROM jsonTable @@ -865,7 +863,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - ctx.sql( + sql( """ |SELECT _unparsed |FROM jsonTable @@ -900,7 +898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(schema === jsonDF.schema) checkAnswer( - ctx.sql( + sql( """ |SELECT field1, field2, field3, field4 |FROM jsonTable @@ -963,7 +961,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( - ctx.sql("select * from primativeTable"), + sql("select * from primativeTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -977,19 +975,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( - ctx.sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( - ctx.sql("select arrayOfNull from complexTable"), + sql("select arrayOfNull from complexTable"), Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - ctx.sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) @@ -997,25 +995,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Access elements of an array of arrays. checkAnswer( - ctx.sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), + sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( - ctx.sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), + sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). 
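The nested-field queries in this suite all run against a temp table built from JSON records with struct and array columns. A minimal sketch of that setup, using only the Spark 1.x API already visible above (ctx.read.json on an RDD of strings, registerTempTable, and the bare sql helper); the JSON literal and table name are illustrative, not the suite's actual fixture:

// Assumes the surrounding test fixture provides `ctx: SQLContext` and the `sql` helper.
val records = ctx.sparkContext.parallelize(
  """{"structWithArrayFields": {"field1": [4, 5, 6], "field2": ["str1", "str2"]}}""" :: Nil)
val nestedDF = ctx.read.json(records)
nestedDF.registerTempTable("nestedExample")

// Struct fields are addressed with dot syntax and array elements by ordinal index,
// so field1[1] picks the second element of the array inside the struct.
sql("select structWithArrayFields.field1[1] from nestedExample").head()   // Row(5)
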
checkAnswer( - ctx.sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), Row("str2", 2.1) ) // Access a struct and fields inside of it. checkAnswer( - ctx.sql("select struct, struct.field1, struct.field2 from complexTable"), + sql("select struct, struct.field1, struct.field2 from complexTable"), Row( Row(true, new java.math.BigDecimal("92233720368547758070")), true, @@ -1024,14 +1022,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Access an array field of a struct. checkAnswer( - ctx.sql( - "select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( - ctx.sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + "from complexTable"), Row(5, null) ) @@ -1158,11 +1155,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "abd") ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) - checkAnswer(ctx.sql( + checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index be3afaa87abc..ed8bafb10c60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -282,7 +282,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -290,7 +290,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, i.toString, pi, ps)) checkAnswer( - ctx.sql("SELECT intField, pi FROM t"), + sql("SELECT intField, pi FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -298,14 +298,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, pi)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi = 1"), + sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, i.toString, 1, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), + sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -329,7 +329,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -337,7 +337,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, pi, i.toString, ps)) checkAnswer( - ctx.sql("SELECT intField, pi FROM t"), + sql("SELECT intField, pi 
FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -345,14 +345,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, pi)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi = 1"), + sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") } yield Row(i, 1, i.toString, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps = 'foo'"), + sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -378,7 +378,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -386,14 +386,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, i.toString, pi, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE pi IS NULL"), + sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) } yield Row(i, i.toString, null, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps IS NULL"), + sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) @@ -418,7 +418,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -426,7 +426,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } yield Row(i, pi, i.toString, ps)) checkAnswer( - ctx.sql("SELECT * FROM t WHERE ps IS NULL"), + sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) @@ -454,7 +454,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempTable("t") { checkAnswer( - ctx.sql("SELECT * FROM t"), + sql("SELECT * FROM t"), (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 11c5818657a1..e2f2a8c74478 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -33,9 +33,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(ctx.sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) - checkAnswer( - ctx.sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) } } @@ -43,7 +42,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val data = (0 until 10).map(i => (i, i.toString)) ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - ctx.sql("INSERT INTO TABLE t SELECT * FROM tmp") + sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } ctx.catalog.unregisterTable(Seq("tmp")) @@ -53,7 +52,7 @@ class ParquetQuerySuite extends 
QueryTest with ParquetTest with SharedSQLContext val data = (0 until 10).map(i => (i, i.toString)) ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - ctx.sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } ctx.catalog.unregisterTable(Seq("tmp")) @@ -67,7 +66,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } withParquetTable(data, "t") { - val selfJoin = ctx.sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) @@ -82,7 +81,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("nested data - struct with array field") { val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withParquetTable(data, "t") { - checkAnswer(ctx.sql("SELECT _1._2[0] FROM t"), data.map { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) }) } @@ -91,7 +90,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("nested data - array of struct") { val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withParquetTable(data, "t") { - checkAnswer(ctx.sql("SELECT _1[0]._2 FROM t"), data.map { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) }) } @@ -99,17 +98,17 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(ctx.sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } test("SPARK-5309 strings stored using dictionary compression in parquet") { withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(ctx.sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), (0 until 10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(ctx.sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), List(Row("same", "run_5", 100))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3d611147b8e5..0edac0848c3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -64,14 +64,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE fetchtwo |USING org.apache.spark.sql.jdbc @@ -79,7 +79,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + 
sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc @@ -94,7 +94,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext conn.prepareStatement("insert into test.inttypes values (null, null, null, null, null)" ).executeUpdate() conn.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc @@ -111,7 +111,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext stmt.setBytes(5, testBytes) stmt.setString(6, "I am a clob!") stmt.executeUpdate() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc @@ -125,7 +125,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + "null, '2002-02-20 11:22:33.543543543')").executeUpdate() conn.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc @@ -140,7 +140,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext + "1.00000011920928955078125, " + "123456789012345.543215432154321)").executeUpdate() conn.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc @@ -157,7 +157,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext + "null, null, null, null, null, null, null, null, null, " + "null, null, null, null, null, null)").executeUpdate() conn.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE nulltypes |USING org.apache.spark.sql.jdbc @@ -172,26 +172,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT *") { - assert(ctx.sql("SELECT * FROM foobar").collect().size === 3) + assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(ctx.sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(ctx.sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(ctx.sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(ctx.sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(ctx.sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(ctx.sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) } test("SELECT * WHERE (quoted strings)") { - assert( - ctx.sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { - val names = - ctx.sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) @@ -199,8 +197,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT first field when 
fetchSize is two") { - val names = - ctx.sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) @@ -208,7 +205,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT second field") { - val ids = ctx.sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -216,7 +213,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT second field when fetchSize is two") { - val ids = ctx.sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -224,17 +221,17 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * partitioned") { - assert(ctx.sql("SELECT * FROM parts").collect().size == 3) + assert(sql("SELECT * FROM parts").collect().size == 3) } test("SELECT WHERE (simple predicates) partitioned") { - assert(ctx.sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) - assert(ctx.sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) - assert(ctx.sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { - val ids = ctx.sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) + val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) assert(ids(1) === 2) @@ -243,7 +240,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Register JDBC query with renamed fields") { // Regression test for bug SPARK-7345 - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE renamed |USING org.apache.spark.sql.jdbc @@ -251,7 +248,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext |user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - val df = ctx.sql("SELECT * FROM renamed") + val df = sql("SELECT * FROM renamed") assert(df.schema.fields.size == 2) assert(df.schema.fields(0).name == "NAME1") assert(df.schema.fields(1).name == "NAME2") @@ -282,7 +279,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 integral types") { - val rows = ctx.sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() + val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) assert(rows(0).getInt(0) === 1) assert(rows(0).getBoolean(1) === false) @@ -292,7 +289,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 null entries") { - val rows = ctx.sql("SELECT * FROM inttypes WHERE A IS NULL").collect() + val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() assert(rows.length === 1) assert(rows(0).isNullAt(0)) 
assert(rows(0).isNullAt(1)) @@ -302,7 +299,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 string types") { - val rows = ctx.sql("SELECT * FROM strtypes").collect() + val rows = sql("SELECT * FROM strtypes").collect() assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes)) assert(rows(0).getString(1).equals("Sensitive")) assert(rows(0).getString(2).equals("Insensitive")) @@ -312,7 +309,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 time types") { - val rows = ctx.sql("SELECT * FROM timetypes").collect() + val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) assert(cal.get(Calendar.HOUR_OF_DAY) === 12) @@ -346,7 +343,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") - val cachedRows = ctx.sql("select * from mycached_date").collect() + val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } @@ -358,26 +355,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("H2 floating-point types") { - val rows = ctx.sql("SELECT * FROM flttypes").collect() + val rows = sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) assert(rows(0).getDouble(1) === 1.00000011920928955) assert(rows(0).getAs[BigDecimal](2) === new BigDecimal("123456789012345.543215432154321000")) assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) - val result = ctx.sql("SELECT C FROM flttypes where C > C - 1").collect() + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() assert(result(0).getAs[BigDecimal](0) === new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - val rows = ctx.sql("SELECT * FROM hack").collect() + val rows = sql("SELECT * FROM hack").collect() assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) @@ -387,7 +384,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext // We set rowId to false during setup, which means that _ROWID_ column should be absent from // all tables. If rowId is true (default), the query below doesn't throw an exception. 
intercept[JdbcSQLException] { - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE abc |USING org.apache.spark.sql.jdbc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 78f521d380f2..5dc3a2c07b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -56,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -142,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 59e363ef8dbd..9bc3f6bcf6fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null @@ -51,7 +52,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("CREATE TEMPORARY TABLE AS SELECT") { - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -62,8 +63,8 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), - caseInsensitiveContext.sql("SELECT a, b FROM jt").collect()) + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") } @@ -75,7 +76,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with path.setWritable(false) val e = intercept[IOException] { - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -84,7 +85,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with |) AS |SELECT a, b FROM jt """.stripMargin) - 
caseInsensitiveContext.sql("SELECT a, b FROM jsonTable").collect() + sql("SELECT a, b FROM jsonTable").collect() } assert(e.getMessage().contains("Unable to clear output directory")) @@ -92,7 +93,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("create a table, drop it and create another one with the same name") { - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -103,11 +104,11 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), - caseInsensitiveContext.sql("SELECT a, b FROM jt").collect()) + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) val message = intercept[DDLException]{ - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING json @@ -122,7 +123,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") // Overwrite the temporary table. - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -132,14 +133,14 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with |SELECT a * 4 FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM jsonTable"), - caseInsensitiveContext.sql("SELECT a * 4 FROM jt").collect()) + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") // Explicitly delete the data. if (path.exists()) Utils.deleteRecursively(path) - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -150,15 +151,15 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM jsonTable"), - caseInsensitiveContext.sql("SELECT b FROM jt").collect()) + sql("SELECT * FROM jsonTable"), + sql("SELECT b FROM jt").collect()) caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { val message = intercept[DDLException]{ - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING json @@ -175,7 +176,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with test("a CTAS statement with column definitions is not allowed") { intercept[DDLException]{ - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) |USING json @@ -188,7 +189,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("it is not allowed to write to a table while querying it.") { - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json @@ -199,7 +200,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with """.stripMargin) val message = intercept[AnalysisException] { - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable |USING json diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 59cdb3fd6cca..5f8514e1a241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -70,10 +70,11 @@ case 
class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { super.beforeAll() - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -107,7 +108,7 @@ class DDLTestSuite extends DataSourceTest with SharedSQLContext { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 12ba1ec6accd..c81c3d398280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -98,10 +98,11 @@ object FiltersPushed { } class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { super.beforeAll() - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered |USING org.apache.spark.sql.sources.FilteredScanSource @@ -238,7 +239,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 925f9647d7c5..78bd3e558296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null @@ -32,7 +33,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") - caseInsensitiveContext.sql( + sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) |USING org.apache.spark.sql.json.DefaultSource @@ -53,42 +54,42 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("Simple INSERT OVERWRITE a JSONRelation") { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) } test("PreInsert casting and renaming") 
{ - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 2, s"${i * 4}")) ) - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 4, s"${i * 6}")) ) } test("SELECT clause generating a different number of columns is not allowed.") { val message = intercept[RuntimeException] { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) @@ -100,45 +101,45 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("INSERT OVERWRITE a JSONRelation multiple times") { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1 """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 10, s"str$i")) ) @@ -147,22 +148,22 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("INSERT INTO JSONRelation for now") { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), - caseInsensitiveContext.sql("SELECT a, b FROM jt").collect() + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect() ) - caseInsensitiveContext.sql( + sql( s""" |INSERT INTO TABLE jsonTable SELECT a, b FROM jt """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), - caseInsensitiveContext.sql("SELECT a, b FROM jt UNION ALL SELECT a, b FROM jt").collect() + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt UNION ALL SELECT a, b FROM jt").collect() ) } @@ -170,20 +171,20 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") .write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), 
(1 to 10).map(i => Row(i * 5, s"str$i")) ) caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( - caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) } test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable """.stripMargin) @@ -195,58 +196,58 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { test("Caching") { // write something to the jsonTable - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) // Cached Query Execution caseInsensitiveContext.cacheTable("jsonTable") - assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable")) + assertCached(sql("SELECT * FROM jsonTable")) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM jsonTable"), + sql("SELECT * FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i"))) - assertCached(caseInsensitiveContext.sql("SELECT a FROM jsonTable")) + assertCached(sql("SELECT a FROM jsonTable")) checkAnswer( - caseInsensitiveContext.sql("SELECT a FROM jsonTable"), + sql("SELECT a FROM jsonTable"), (1 to 10).map(Row(_)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT a FROM jsonTable WHERE a < 5")) + assertCached(sql("SELECT a FROM jsonTable WHERE a < 5")) checkAnswer( - caseInsensitiveContext.sql("SELECT a FROM jsonTable WHERE a < 5"), + sql("SELECT a FROM jsonTable WHERE a < 5"), (1 to 4).map(Row(_)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT a * 2 FROM jsonTable")) + assertCached(sql("SELECT a * 2 FROM jsonTable")) checkAnswer( - caseInsensitiveContext.sql("SELECT a * 2 FROM jsonTable"), + sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(caseInsensitiveContext.sql( + assertCached(sql( "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer(caseInsensitiveContext.sql( + checkAnswer(sql( "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt """.stripMargin) // jsonTable should be recached. - assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable")) + assertCached(sql("SELECT * FROM jsonTable")) // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation // // The cached data is the new data. 
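The caching tests in these data source suites lean on two helpers that are easy to miss in the diff: cacheTable/uncacheTable on the context, and QueryTest's assertCached, which counts the in-memory relations in a query's plan. A condensed sketch of the verification pattern, combining the cache and uncache checks used across these suites and assuming the suite's caseInsensitiveContext and its sql helper:

caseInsensitiveContext.cacheTable("jsonTable")

// assertCached (a QueryTest helper) expects one cached relation in the plan by default.
assertCached(sql("SELECT * FROM jsonTable"))

// After INSERT OVERWRITE the table should be re-cached rather than served stale.
sql("INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt")
assertCached(sql("SELECT * FROM jsonTable"))

// Passing an explicit count of 0 verifies that uncaching really took effect.
caseInsensitiveContext.uncacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"), 0)
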
// checkAnswer( -// caseInsensitiveContext.sql("SELECT a, b FROM jsonTable"), -// caseInsensitiveContext.sql("SELECT a * 2, b FROM jt").collect()) +// sql("SELECT a, b FROM jsonTable"), +// sql("SELECT a * 2, b FROM jt").collect()) // // // Verify uncaching // caseInsensitiveContext.uncacheTable("jsonTable") -// assertCached(caseInsensitiveContext.sql("SELECT * FROM jsonTable"), 0) +// assertCached(sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource @@ -257,12 +258,12 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM oneToTen"), + sql("SELECT * FROM oneToTen"), (1 to 10).map(Row(_)).toSeq ) val message = intercept[AnalysisException] { - caseInsensitiveContext.sql( + sql( s""" |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index c5dd8aae07b5..a89c5f8007e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -53,10 +53,11 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { super.beforeAll() - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -116,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 8c463ba8802a..f18546b4c2d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null @@ -70,7 +71,7 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), - caseInsensitiveContext.sql(s"SELECT b FROM $tbl").collect()) + sql(s"SELECT b FROM $tbl").collect()) } test("save with path and load") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 853273679864..12af8068c398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -97,6 +97,8 @@ case class AllDataTypesScan( } class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ + private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", @@ -123,7 +125,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { override def beforeAll(): Unit = { super.beforeAll() - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource @@ -135,7 +137,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |) """.stripMargin) - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE tableWithSchema ( |`string$%Field` stRIng, @@ -230,7 +232,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( - caseInsensitiveContext.sql( + sql( """SELECT | `string$%Field`, | cast(binaryField as string), @@ -283,39 +285,39 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("Caching") { // Cached Query Execution caseInsensitiveContext.cacheTable("oneToTen") - assertCached(caseInsensitiveContext.sql("SELECT * FROM oneToTen")) + assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM oneToTen"), + sql("SELECT * FROM oneToTen"), (1 to 10).map(Row(_)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT i FROM oneToTen")) + assertCached(sql("SELECT i FROM oneToTen")) checkAnswer( - caseInsensitiveContext.sql("SELECT i FROM oneToTen"), + sql("SELECT i FROM oneToTen"), (1 to 10).map(Row(_)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT i FROM oneToTen WHERE i < 5")) + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) checkAnswer( - caseInsensitiveContext.sql("SELECT i FROM oneToTen WHERE i < 5"), + sql("SELECT i FROM oneToTen WHERE i < 5"), (1 to 4).map(Row(_)).toSeq) - assertCached(caseInsensitiveContext.sql("SELECT i * 2 FROM oneToTen")) + assertCached(sql("SELECT i * 2 FROM oneToTen")) checkAnswer( - caseInsensitiveContext.sql("SELECT i * 2 FROM oneToTen"), + sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(caseInsensitiveContext.sql( + assertCached(sql( "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer(caseInsensitiveContext.sql( + checkAnswer(sql( "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching caseInsensitiveContext.uncacheTable("oneToTen") - assertCached(caseInsensitiveContext.sql("SELECT * FROM oneToTen"), 0) + assertCached(sql("SELECT * FROM oneToTen"), 0) } test("defaultSource") { - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE oneToTenDef |USING org.apache.spark.sql.sources @@ -326,7 +328,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { """.stripMargin) checkAnswer( - caseInsensitiveContext.sql("SELECT * FROM oneToTenDef"), + sql("SELECT * FROM oneToTenDef"), (1 to 10).map(Row(_)).toSeq) } @@ -334,7 +336,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { // Make sure we do throw correct exception when 
users use a relation provider that // only implements the RelationProvier or the SchemaRelationProvider. val schemaNotAllowed = intercept[Exception] { - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE relationProvierWithSchema (i int) |USING org.apache.spark.sql.sources.SimpleScanSource @@ -347,7 +349,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) val schemaNeeded = intercept[Exception] { - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE schemaRelationProvierWithoutSchema |USING org.apache.spark.sql.sources.AllDataTypesScanSource @@ -361,7 +363,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { } test("SPARK-5196 schema field with comment") { - caseInsensitiveContext.sql( + sql( """ |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int) |USING org.apache.spark.sql.sources.AllDataTypesScanSource @@ -373,7 +375,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |) """.stripMargin) - val planned = caseInsensitiveContext.sql("SELECT * FROM student").queryExecution.executedPlan + val planned = sql("SELECT * FROM student").queryExecution.executedPlan val comments = planned.schema.fields.map { field => if (field.metadata.contains("comment")) field.metadata.getString("comment") else "NO_COMMENT" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 85d73d8e78a2..1374a97476ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -36,14 +36,14 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. 
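A minimal illustrative sketch (not taken from the patch) of the lazy-registration pattern the note above describes: the trait only sees an abstract `_sqlContext` hook, so every dataset is a lazy val and nothing dereferences the context until a test first asks for the data. The trait, column, and table names below are hypothetical.

import org.apache.spark.sql.{DataFrame, SQLContext}

trait ExampleTestData {
  // Supplied later by whichever suite mixes this trait in.
  protected def _sqlContext: SQLContext

  // Lazy so that _sqlContext is not touched until a test first uses the data.
  protected lazy val exampleData: DataFrame = {
    val ctx = _sqlContext                 // stable identifier, needed for the implicits import
    import ctx.implicits._
    val df = ctx.sparkContext
      .parallelize((1 to 10).map(i => (i, i.toString)))
      .toDF("a", "b")
    df.registerTempTable("exampleData")   // registered on first access, like the vals in this hunk
    df
  }
}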
- lazy val testData: DataFrame = { + protected lazy val testData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } - lazy val testData2: DataFrame = { + protected lazy val testData2: DataFrame = { val df = _sqlContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: @@ -55,7 +55,7 @@ private[sql] trait SQLTestData { self => df } - lazy val testData3: DataFrame = { + protected lazy val testData3: DataFrame = { val df = _sqlContext.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() @@ -63,14 +63,14 @@ private[sql] trait SQLTestData { self => df } - lazy val negativeData: DataFrame = { + protected lazy val negativeData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } - lazy val largeAndSmallInts: DataFrame = { + protected lazy val largeAndSmallInts: DataFrame = { val df = _sqlContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: @@ -82,7 +82,7 @@ private[sql] trait SQLTestData { self => df } - lazy val decimalData: DataFrame = { + protected lazy val decimalData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: @@ -94,7 +94,7 @@ private[sql] trait SQLTestData { self => df } - lazy val binaryData: DataFrame = { + protected lazy val binaryData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( BinaryData("12".getBytes, 1) :: BinaryData("22".getBytes, 5) :: @@ -105,7 +105,7 @@ private[sql] trait SQLTestData { self => df } - lazy val upperCaseData: DataFrame = { + protected lazy val upperCaseData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: @@ -117,7 +117,7 @@ private[sql] trait SQLTestData { self => df } - lazy val lowerCaseData: DataFrame = { + protected lazy val lowerCaseData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: @@ -127,7 +127,7 @@ private[sql] trait SQLTestData { self => df } - lazy val arrayData: RDD[ArrayData] = { + protected lazy val arrayData: RDD[ArrayData] = { val rdd = _sqlContext.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) @@ -135,7 +135,7 @@ private[sql] trait SQLTestData { self => rdd } - lazy val mapData: RDD[MapData] = { + protected lazy val mapData: RDD[MapData] = { val rdd = _sqlContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: @@ -146,13 +146,13 @@ private[sql] trait SQLTestData { self => rdd } - lazy val repeatedData: RDD[StringData] = { + protected lazy val repeatedData: RDD[StringData] = { val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } - lazy val nullableRepeatedData: RDD[StringData] = { + protected lazy val nullableRepeatedData: RDD[StringData] = { val rdd = _sqlContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) @@ -160,7 +160,7 @@ private[sql] trait SQLTestData { self => rdd } - lazy val nullInts: DataFrame = { + protected lazy val nullInts: DataFrame = { val df = _sqlContext.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: @@ -170,7 
+170,7 @@ private[sql] trait SQLTestData { self => df } - lazy val allNulls: DataFrame = { + protected lazy val allNulls: DataFrame = { val df = _sqlContext.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: @@ -180,7 +180,7 @@ private[sql] trait SQLTestData { self => df } - lazy val nullStrings: DataFrame = { + protected lazy val nullStrings: DataFrame = { val df = _sqlContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: @@ -189,13 +189,13 @@ private[sql] trait SQLTestData { self => df } - lazy val tableName: DataFrame = { + protected lazy val tableName: DataFrame = { val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } - lazy val unparsedStrings: RDD[String] = { + protected lazy val unparsedStrings: RDD[String] = { _sqlContext.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: @@ -204,13 +204,13 @@ private[sql] trait SQLTestData { self => } // An RDD with 4 elements and 8 partitions - lazy val withEmptyParts: RDD[IntField] = { + protected lazy val withEmptyParts: RDD[IntField] = { val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } - lazy val person: DataFrame = { + protected lazy val person: DataFrame = { val df = _sqlContext.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() @@ -218,7 +218,7 @@ private[sql] trait SQLTestData { self => df } - lazy val salary: DataFrame = { + protected lazy val salary: DataFrame = { val df = _sqlContext.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() @@ -226,7 +226,7 @@ private[sql] trait SQLTestData { self => df } - lazy val complexData: DataFrame = { + protected lazy val complexData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 8c1d9c180e7e..cdd691e03589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -51,6 +51,9 @@ private[sql] trait SQLTestUtils // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false + // Shorthand for running a query using our SQLContext + protected lazy val sql = _sqlContext.sql _ + /** * A helper object for importing SQL implicits. 
* diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 59e65ff97b8e..46770de2a1cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode, SQLContext} import org.apache.spark.{Logging, SparkFunSuite} @@ -53,7 +53,7 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive private val testDF = range(1, 3).select( ('id + 0.1) cast DecimalType(10, 3) as 'd1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d..311a33711fc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -30,8 +30,8 @@ class UDFSuite extends QueryTest { ctx.udf.register("random0", () => { Math.random() }) ctx.udf.register("RANDOM1", () => { Math.random() }) ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } From 0b60325eeb1e27626beb3c1c73b1c604be24cbfd Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 15:02:42 -0700 Subject: [PATCH 39/39] Fix hive test compile --- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 1 + .../org/apache/spark/sql/hive/HiveParquetSuite.scala | 11 +++++------ .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 4 ++-- .../apache/spark/sql/hive/MultiDatabaseSuite.scala | 5 ++--- .../sql/hive/ParquetHiveCompatibilitySuite.scala | 3 ++- .../scala/org/apache/spark/sql/hive/UDFSuite.scala | 10 ++++------ .../sql/hive/execution/AggregationQuerySuite.scala | 9 +++++---- .../spark/sql/hive/execution/HiveExplainSuite.scala | 4 ++-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 3 ++- .../hive/execution/ScriptTransformationSuite.scala | 3 ++- .../scala/org/apache/spark/sql/hive/orc/OrcTest.scala | 4 ++-- .../org/apache/spark/sql/hive/parquetSuites.scala | 3 ++- .../sql/sources/CommitFailureTestRelationSuite.scala | 6 ++++-- .../spark/sql/sources/hadoopFsRelationSuites.scala | 5 ++--- 14 files changed, 37 insertions(+), 34 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 46770de2a1cb..574624d501f2 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -54,6 +54,7 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { override def _sqlContext: SQLContext = TestHive + import testImplicits._ private val testDF = range(1, 3).select( ('id + 0.1) cast DecimalType(10, 3) as 'd1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 1fa005d5f9a1..fe0db5228de1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ + private val ctx = TestHive + override def _sqlContext: SQLContext = ctx test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7f36a483a396..20a50586d520 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,7 +22,6 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging @@ -42,7 +41,8 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll with Logging { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20..417e8b07917c 
100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override val _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 251e0324bfa5..13452e71a1b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - override val sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext /** * Set the staging directory (and hence path to ignore Parquet files under) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 311a33711fc7..7ee1c8d13aa3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,21 +17,19 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) ctx.udf.register("RANDOM1", () => { Math.random() }) ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9..a312f8495824 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql._ -import 
org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 44c5b80392fa..11d7a872dff0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. */ class HiveExplainSuite extends QueryTest with SQLTestUtils { - - def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 79a136ae6f61..8b8f520776e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("UDTF") { sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3..9aca40f15ac1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da0..f7ba20ff41d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - + protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ import 
sqlContext.sparkContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 50f02432dacc..34d3434569f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { * A collection of tests for parquet data with various forms of partitioning. */ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706..b4640b161628 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a69d331b6e5..af445626fbe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -34,9 +34,8 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ val dataSourceName: String
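Taken together, the suites in this series converge on a single wiring pattern: expose the concrete context through the `_sqlContext` hook and reuse the `sql` shorthand inherited from `SQLTestUtils`. Below is a minimal sketch of a converted Hive-backed suite under that pattern; the suite name and test body are illustrative only and assume nothing beyond what the diffs above show.

import org.apache.spark.sql.{QueryTest, Row, SQLContext}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils

class ExampleHiveSuite extends QueryTest with SQLTestUtils {
  override def _sqlContext: SQLContext = TestHive
  private val sqlContext = _sqlContext

  test("query a temp table through the inherited sql shorthand") {
    sqlContext.range(1, 4).registerTempTable("example")   // rows with id = 1, 2, 3
    checkAnswer(
      sql("SELECT id FROM example ORDER BY id"),
      Seq(Row(1L), Row(2L), Row(3L)))
  }
}

Because `sql` is just a lazy val bound to `_sqlContext.sql _`, a suite that needs different behaviour can rebind it, which is what the data-source suites above do with `protected override lazy val sql = caseInsensitiveContext.sql _`.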