Skip to content

Commit ebc8c61

Browse files
committed
use SQLTestUtils and If
1 parent 625973c commit ebc8c61

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,9 @@ trait HiveTypeCoercion {
497497
}
498498

499499
private def transform(booleanExpr: Expression, numericExpr: Expression) = {
500-
CaseWhen(Seq(
501-
Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal.create(null, BooleanType),
502-
buildCaseKeyWhen(booleanExpr, numericExpr)
503-
))
500+
If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
501+
Literal.create(null, BooleanType),
502+
buildCaseKeyWhen(booleanExpr, numericExpr))
504503
}
505504

506505
private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ import org.apache.spark.sql.catalyst.errors.DialectException
2424
import org.apache.spark.sql.execution.GeneratedAggregate
2525
import org.apache.spark.sql.functions._
2626
import org.apache.spark.sql.TestData._
27-
import org.apache.spark.sql.test.TestSQLContext
27+
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
2828
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
2929

3030
import org.apache.spark.sql.types._
3131

3232
/** A SQL Dialect for testing purpose, and it can not be nested type */
3333
class MyDialect extends DefaultParserDialect
3434

35-
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
35+
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
3636
// Make sure the tables are loaded.
3737
TestData
3838

39-
import org.apache.spark.sql.test.TestSQLContext.implicits._
40-
val sqlCtx = TestSQLContext
39+
val sqlContext = TestSQLContext
40+
import sqlContext.implicits._
4141

4242
test("SPARK-6743: no columns from cache") {
4343
Seq(
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
915915
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
916916
}
917917

918-
val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
918+
val df1 = createDataFrame(rowRDD1, schema1)
919919
df1.registerTempTable("applySchema1")
920920
checkAnswer(
921921
sql("SELECT * FROM applySchema1"),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
945945
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
946946
}
947947

948-
val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
948+
val df2 = createDataFrame(rowRDD2, schema2)
949949
df2.registerTempTable("applySchema2")
950950
checkAnswer(
951951
sql("SELECT * FROM applySchema2"),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
970970
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
971971
}
972972

973-
val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
973+
val df3 = createDataFrame(rowRDD3, schema2)
974974
df3.registerTempTable("applySchema3")
975975

976976
checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
10151015
.build()
10161016
val schemaWithMeta = new StructType(Array(
10171017
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
1018-
val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
1018+
val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
10191019
def validateMetadata(rdd: DataFrame): Unit = {
10201020
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
10211021
}
@@ -1333,23 +1333,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
13331333
}
13341334

13351335
test("SPARK-7952: fix the equality check between boolean and numeric types") {
1336-
val df = Seq(
1337-
(1, true),
1338-
(0, false),
1339-
(2, true),
1340-
(2, false),
1341-
(null, true),
1342-
(null, false),
1343-
(0, null),
1344-
(1, null),
1345-
(null, null)
1346-
).map { case (i, b) =>
1347-
(i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
1348-
}.toDF("i", "b")
1349-
1350-
checkAnswer(df.select('i === 'b),
1351-
Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
1352-
checkAnswer(df.select('i <=> 'b),
1353-
Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
1336+
withTempTable("t") {
1337+
Seq(
1338+
(1, true),
1339+
(0, false),
1340+
(2, true),
1341+
(2, false),
1342+
(null, true),
1343+
(null, false),
1344+
(0, null),
1345+
(1, null),
1346+
(null, null)
1347+
).map { case (i, b) =>
1348+
(i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
1349+
}.toDF("i", "b").registerTempTable("t")
1350+
1351+
checkAnswer(sql("select i = b from t"),
1352+
Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
1353+
checkAnswer(sql("select i <=> b from t"),
1354+
Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
1355+
}
13541356
}
13551357
}

0 commit comments

Comments
 (0)