@@ -24,20 +24,20 @@ import org.apache.spark.sql.catalyst.errors.DialectException
2424import org .apache .spark .sql .execution .GeneratedAggregate
2525import org .apache .spark .sql .functions ._
2626import org .apache .spark .sql .TestData ._
27- import org .apache .spark .sql .test .TestSQLContext
27+ import org .apache .spark .sql .test .{ SQLTestUtils , TestSQLContext }
2828import org .apache .spark .sql .test .TestSQLContext .{udf => _ , _ }
2929
3030import org .apache .spark .sql .types ._
3131
3232/** A SQL Dialect for testing purpose, and it can not be nested type */
3333class MyDialect extends DefaultParserDialect
3434
35- class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
35+ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
3636 // Make sure the tables are loaded.
3737 TestData
3838
39- import org . apache . spark . sql . test . TestSQLContext . implicits . _
40- val sqlCtx = TestSQLContext
39+ val sqlContext = TestSQLContext
40+ import sqlContext . implicits . _
4141
4242 test(" SPARK-6743: no columns from cache" ) {
4343 Seq (
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
915915 Row (values(0 ).toInt, values(1 ), values(2 ).toBoolean, v4)
916916 }
917917
918- val df1 = sqlCtx. createDataFrame(rowRDD1, schema1)
918+ val df1 = createDataFrame(rowRDD1, schema1)
919919 df1.registerTempTable(" applySchema1" )
920920 checkAnswer(
921921 sql(" SELECT * FROM applySchema1" ),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
945945 Row (Row (values(0 ).toInt, values(2 ).toBoolean), Map (values(1 ) -> v4))
946946 }
947947
948- val df2 = sqlCtx. createDataFrame(rowRDD2, schema2)
948+ val df2 = createDataFrame(rowRDD2, schema2)
949949 df2.registerTempTable(" applySchema2" )
950950 checkAnswer(
951951 sql(" SELECT * FROM applySchema2" ),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
970970 Row (Row (values(0 ).toInt, values(2 ).toBoolean), scala.collection.mutable.Map (values(1 ) -> v4))
971971 }
972972
973- val df3 = sqlCtx. createDataFrame(rowRDD3, schema2)
973+ val df3 = createDataFrame(rowRDD3, schema2)
974974 df3.registerTempTable(" applySchema3" )
975975
976976 checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
10151015 .build()
10161016 val schemaWithMeta = new StructType (Array (
10171017 schema(" id" ), schema(" name" ).copy(metadata = metadata), schema(" age" )))
1018- val personWithMeta = sqlCtx. createDataFrame(person.rdd, schemaWithMeta)
1018+ val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
10191019 def validateMetadata (rdd : DataFrame ): Unit = {
10201020 assert(rdd.schema(" name" ).metadata.getString(docKey) == docValue)
10211021 }
@@ -1333,23 +1333,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
13331333 }
13341334
13351335 test(" SPARK-7952: fix the equality check between boolean and numeric types" ) {
1336- val df = Seq (
1337- (1 , true ),
1338- (0 , false ),
1339- (2 , true ),
1340- (2 , false ),
1341- (null , true ),
1342- (null , false ),
1343- (0 , null ),
1344- (1 , null ),
1345- (null , null )
1346- ).map { case (i, b) =>
1347- (i.asInstanceOf [Integer ], b.asInstanceOf [java.lang.Boolean ])
1348- }.toDF(" i" , " b" )
1349-
1350- checkAnswer(df.select(' i === ' b ),
1351- Seq (true , true , false , false , null , null , null , null , null ).map(Row (_)))
1352- checkAnswer(df.select(' i <=> ' b ),
1353- Seq (true , true , false , false , false , false , false , false , true ).map(Row (_)))
1336+ withTempTable(" t" ) {
1337+ Seq (
1338+ (1 , true ),
1339+ (0 , false ),
1340+ (2 , true ),
1341+ (2 , false ),
1342+ (null , true ),
1343+ (null , false ),
1344+ (0 , null ),
1345+ (1 , null ),
1346+ (null , null )
1347+ ).map { case (i, b) =>
1348+ (i.asInstanceOf [Integer ], b.asInstanceOf [java.lang.Boolean ])
1349+ }.toDF(" i" , " b" ).registerTempTable(" t" )
1350+
1351+ checkAnswer(sql(" select i = b from t" ),
1352+ Seq (true , true , false , false , null , null , null , null , null ).map(Row (_)))
1353+ checkAnswer(sql(" select i <=> b from t" ),
1354+ Seq (true , true , false , false , false , false , false , false , true ).map(Row (_)))
1355+ }
13541356 }
13551357}
0 commit comments