diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 2b393f30d143..92379539290d 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1524,6 +1524,14 @@ that these options will be deprecated in future release as more optimizations ar
Configures the number of partitions to use when shuffling data for joins or aggregations.
+<tr>
+  <td><code>spark.sql.typeCoercion.mode</code></td>
+  <td>default</td>
+  <td>
+    Since Spark 2.4, the <code>hive</code> mode is available for Hive-compatible type coercion
+    in binary comparisons. Spark SQL's native type coercion mode, <code>default</code>, is enabled by default.
+  </td>
+</tr>
## Broadcast Hint for SQL Queries
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index ec7e7761dc4c..9e2754442760 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -122,12 +122,21 @@ object TypeCoercion {
case _ => None
}
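+
+  /**
+   * Dispatches type coercion for binary comparisons to either the Hive-compatible
+   * rules or Spark's native rules, depending on spark.sql.typeCoercion.mode.
+   */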
+ private def findCommonTypeForBinaryComparison(
+ dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = {
+ if (conf.isHiveTypeCoercionMode) {
+ findHiveCommonTypeForBinary(dt1, dt2)
+ } else {
+ findNativeCommonTypeForBinary(dt1, dt2, conf)
+ }
+ }
+
/**
* This function determines the target type of a comparison operator when one operand
* is a String and the other is not. It also handles when one op is a Date and the
* other is a Timestamp by making the target type to be String.
*/
- private def findCommonTypeForBinaryComparison(
+ private def findNativeCommonTypeForBinary(
dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match {
// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
@@ -158,6 +167,28 @@ object TypeCoercion {
case (l, r) => None
}
+ /**
+   * This function follows Hive's behavior for binary comparisons, as implemented in:
+ * https://github.com/apache/hive/blob/rel/release-3.0.0/ql/src/java/
+ * org/apache/hadoop/hive/ql/exec/FunctionRegistry.java#L802
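+   *
+   * For example, under these rules a String compared with a numeric operand is
+   * coerced to DoubleType, and a String compared with a Date is coerced to DateType.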
+ */
+ private def findHiveCommonTypeForBinary(
+ dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
+ case (StringType, DateType) => Some(DateType)
+ case (DateType, StringType) => Some(DateType)
+ case (StringType, TimestampType) => Some(TimestampType)
+ case (TimestampType, StringType) => Some(TimestampType)
+ case (TimestampType, DateType) => Some(TimestampType)
+ case (DateType, TimestampType) => Some(TimestampType)
+ case (StringType, NullType) => Some(StringType)
+ case (NullType, StringType) => Some(StringType)
+    case (StringType | TimestampType, _: NumericType) => Some(DoubleType)
+    case (_: NumericType, StringType | TimestampType) => Some(DoubleType)
+    case (StringType, r: AtomicType) if r != StringType => Some(r)
+    case (l: AtomicType, StringType) if l != StringType => Some(l)
+ case _ => None
+ }
+
/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 13f31a6b2eb9..bfb06ccfbf85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1116,6 +1116,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+  val TYPE_COERCION_MODE = buildConf("spark.sql.typeCoercion.mode")
+    .doc("Type coercion mode for binary comparisons. With the default value 'default', " +
+      "Spark SQL uses its native type coercion rules. Since Spark 2.4, the 'hive' mode " +
+      "can be used to follow Hive's type coercion behavior for compatibility.")
+    .stringConf
+    .transform(_.toLowerCase(Locale.ROOT))
+    .checkValues(Set("default", "hive"))
+    .createWithDefault("default")
+
val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
.internal()
.doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -1563,6 +1572,8 @@ class SQLConf extends Serializable with Logging {
def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE)
+  def isHiveTypeCoercionMode: Boolean = getConf(SQLConf.TYPE_COERCION_MODE) == "hive"
+
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TypeCoercionModeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TypeCoercionModeSuite.scala
new file mode 100644
index 000000000000..eff87ee1f803
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TypeCoercionModeSuite.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.sql.{Date, Timestamp}
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.sql._
+import org.apache.spark.sql.internal.SQLConf
+
+class TypeCoercionModeSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+ private var originalActiveSparkSession: Option[SparkSession] = _
+ private var originalInstantiatedSparkSession: Option[SparkSession] = _
+
+ override protected def beforeAll(): Unit = {
+ originalActiveSparkSession = SparkSession.getActiveSession
+ originalInstantiatedSparkSession = SparkSession.getDefaultSession
+
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ }
+
+ override protected def afterAll(): Unit = {
+ originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx))
+ originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx))
+ }
+
+ private def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(actual, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ private var sparkSession: SparkSession = _
+
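+  // Runs `f` against a fresh local SparkSession configured with the given
+  // spark.sql.typeCoercion.mode, and stops the session afterwards.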
+ private def withTypeCoercionMode[T](typeCoercionMode: String)(f: SparkSession => T): T = {
+ try {
+ val sparkConf = new SparkConf(false)
+ .setMaster("local")
+ .setAppName(this.getClass.getName)
+ .set("spark.ui.enabled", "false")
+ .set("spark.driver.allowMultipleContexts", "true")
+        .set(SQLConf.TYPE_COERCION_MODE.key, typeCoercionMode)
+
+ sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+ f(sparkSession)
+ } finally {
+ if (sparkSession != null) {
+        sparkSession.catalog.dropTempView("v")
+ sparkSession.stop()
+ }
+ }
+ }
+
+ test("CommonTypeForBinaryComparison: StringType vs NumericType") {
+ val str1 = Long.MaxValue.toString + "1"
+ val str2 = Int.MaxValue.toString + "1"
+ val str3 = "10"
+ val str4 = "0"
+ val str5 = "-0.4"
+ val str6 = "0.6"
+
+ val data = Seq(str1, str2, str3, str4, str5, str6)
+
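+    // In 'hive' mode, String vs numeric comparisons are evaluated as DOUBLE, so
+    // strings beyond Long range still compare numerically; in 'default' mode the
+    // string side is cast to the numeric literal's type instead.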
+ val q1 = "SELECT c1 FROM v WHERE c1 > 0"
+ val q2 = "SELECT c1 FROM v WHERE c1 > 0L"
+ val q3 = "SELECT c1 FROM v WHERE c1 = 0"
+ val q4 = "SELECT c1 FROM v WHERE c1 in (0)"
+
+ withTypeCoercionMode("hive") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createOrReplaceTempView("v")
+ checkAnswer(spark.sql(q1), Row(str1) :: Row(str2) :: Row(str3) :: Row(str6) :: Nil)
+ checkAnswer(spark.sql(q2), Row(str1) :: Row(str2) :: Row(str3) :: Row(str6) :: Nil)
+ checkAnswer(spark.sql(q3), Row(str4) :: Nil)
+ checkAnswer(spark.sql(q4), Row(str4) :: Nil)
+ }
+
+ withTypeCoercionMode("default") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createOrReplaceTempView("v")
+ checkAnswer(spark.sql(q1), Row(str3) :: Nil)
+ checkAnswer(spark.sql(q2), Row(str2) :: Row(str3) :: Nil)
+ checkAnswer(spark.sql(q3), Row(str4) :: Row(str5) :: Row(str6) :: Nil)
+ checkAnswer(spark.sql(q4), Row(str4) :: Nil)
+ }
+ }
+
+ test("CommonTypeForBinaryComparison: StringType vs DateType") {
+ val v1 = Date.valueOf("2017-09-22")
+ val v2 = Date.valueOf("2017-09-09")
+
+ val data = Seq(v1, v2)
+
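+    // 'hive' mode casts the string literal to DateType (so '2017-8-1' parses as
+    // 2017-08-01), while 'default' mode compares Date vs String as strings, and
+    // '2017-09-22' sorts before '2017-8-1' lexicographically.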
+ val q1 = "SELECT c1 FROM v WHERE c1 > '2017-8-1'"
+ val q2 = "SELECT c1 FROM v WHERE c1 > '2014'"
+ val q3 = "SELECT c1 FROM v WHERE c1 > cast('2017-8-1' as date)"
+
+ withTypeCoercionMode("hive") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Row(v1) :: Row(v2) :: Nil)
+ checkAnswer(spark.sql(q2), Row(v1) :: Row(v2) :: Nil)
+ checkAnswer(spark.sql(q3), Row(v1) :: Row(v2) :: Nil)
+ }
+
+ withTypeCoercionMode("default") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Nil)
+ checkAnswer(spark.sql(q2), Row(v1) :: Row(v2) :: Nil)
+ checkAnswer(spark.sql(q3), Row(v1) :: Row(v2) :: Nil)
+ }
+ }
+
+ test("CommonTypeForBinaryComparison: StringType vs TimestampType") {
+ val v1 = Timestamp.valueOf("2017-07-21 23:42:12.123")
+ val v2 = Timestamp.valueOf("2017-08-21 23:42:12.123")
+ val v3 = Timestamp.valueOf("2017-08-21 23:42:12")
+
+ val data = Seq(v1, v2, v3)
+
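+    // 'hive' mode casts string literals to TimestampType; 'default' mode compares
+    // Timestamp vs String as strings, so '2017-08-21 23:42:12' sorts before
+    // '2017-08-21 23:42:12.0' and v3 satisfies q2 only in 'default' mode.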
+ val q1 = "SELECT c1 FROM v WHERE c1 > '2017-8-1'"
+ val q2 = "SELECT c1 FROM v WHERE c1 < '2017-08-21 23:42:12.0'"
+ val q3 = "SELECT c1 FROM v WHERE c1 > cast('2017-8-1' as timestamp)"
+
+ withTypeCoercionMode("hive") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Row(v2) :: Row(v3) :: Nil)
+ checkAnswer(spark.sql(q2), Row(v1) :: Nil)
+ checkAnswer(spark.sql(q3), Row(v2) :: Row(v3) :: Nil)
+ }
+
+ withTypeCoercionMode("default") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Nil)
+ checkAnswer(spark.sql(q2), Row(v1) :: Row(v3) :: Nil)
+ checkAnswer(spark.sql(q3), Row(v2) :: Row(v3) :: Nil)
+ }
+ }
+
+ test("CommonTypeForBinaryComparison: TimestampType vs DateType") {
+ val v1 = Timestamp.valueOf("2017-07-21 23:42:12.123")
+ val v2 = Timestamp.valueOf("2017-08-21 23:42:12.123")
+
+ val data = Seq(v1, v2)
+
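+    // Date vs Timestamp: 'hive' mode promotes the Date to TimestampType, while
+    // 'default' mode compares both as strings; the two modes agree here.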
+ val q1 = "SELECT c1 FROM v WHERE c1 > cast('2017-8-1' as date)"
+ val q2 = "SELECT c1 FROM v WHERE c1 > cast('2017-8-1' as timestamp)"
+
+ withTypeCoercionMode("hive") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Row(v2) :: Nil)
+ checkAnswer(spark.sql(q2), Row(v2) :: Nil)
+ }
+
+ withTypeCoercionMode("default") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Row(v2) :: Nil)
+ checkAnswer(spark.sql(q2), Row(v2) :: Nil)
+ }
+ }
+
+ test("CommonTypeForBinaryComparison: TimestampType vs NumericType") {
+ val v1 = Timestamp.valueOf("2017-07-21 23:42:12.123")
+ val v2 = Timestamp.valueOf("2017-08-21 23:42:12.123")
+
+ val data = Seq(v1, v2)
+
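+    // 'hive' mode coerces Timestamp vs numeric comparisons to DOUBLE, whereas the
+    // native rules define no common type for Timestamp and numeric operands, so
+    // q1 and q4 fail analysis in 'default' mode.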
+ val q1 = "SELECT c1 FROM v WHERE c1 > 1"
+ val q2 = "SELECT c1 FROM v WHERE c1 > '2017-8-1'"
+ val q3 = "SELECT c1 FROM v WHERE c1 > '2017-08-01'"
+ val q4 = "SELECT c1 FROM v WHERE c1 > cast(cast('2017-08-01' as timestamp) as double)"
+
+ withTypeCoercionMode("hive") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ checkAnswer(spark.sql(q1), Row(v1) :: Row(v2) :: Nil)
+ checkAnswer(spark.sql(q2), Row(v2) :: Nil)
+ checkAnswer(spark.sql(q3), Row(v2) :: Nil)
+ checkAnswer(spark.sql(q4), Row(v2) :: Nil)
+ }
+
+ withTypeCoercionMode("default") { spark =>
+ import spark.implicits._
+ data.toDF("c1").createTempView("v")
+ val e1 = intercept[AnalysisException] {
+ spark.sql(q1)
+ }
+ assert(e1.getMessage.contains("data type mismatch"))
+ checkAnswer(spark.sql(q2), Nil)
+ checkAnswer(spark.sql(q3), Row(v2) :: Nil)
+ val e2 = intercept[AnalysisException] {
+ spark.sql(q4)
+ }
+ assert(e2.getMessage.contains("data type mismatch"))
+ }
+ }
+}