diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index b93f8b795666f..9ad7ad62117cc 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -221,7 +221,6 @@ This is a graphical depiction of the precedence list as a directed tree:
 The least common type from a set of types is the narrowest type reachable from the precedence list by all elements of the set of types.
 
 The least common type resolution is used to:
-- Decide whether a function expecting a parameter of a type can be invoked using an argument of a narrower type.
 - Derive the argument type for functions which expect a shared argument type for multiple parameters, such as coalesce, least, or greatest.
 - Derive the operand types for operators such as arithmetic operations or comparisons.
 - Derive the result type for expressions such as the case expression.
@@ -246,19 +245,40 @@ DOUBLE
 > SELECT (typeof(coalesce(1BD, 1F)));
 DOUBLE
 
--- The substring function expects arguments of type INT for the start and length parameters.
-> SELECT substring('hello', 1Y, 2);
-he
-> SELECT substring('hello', '1', 2);
-he
-> SELECT substring('hello', 1L, 2);
-Error: Argument 2 requires an INT type.
-> SELECT substring('hello', str, 2) FROM VALUES(CAST('1' AS STRING)) AS T(str);
-Error: Argument 2 requires an INT type.
 ```
 
 ### SQL Functions
 
+#### Function invocation
+Under ANSI mode (`spark.sql.ansi.enabled=true`), function invocation in Spark SQL follows these rules:
+- In general, it follows the `Store assignment` rules: input values are implicitly cast to the declared parameter types of the SQL function, as if they were being stored into columns of those types.
+- Special rules apply for string literals and untyped NULL. A NULL can be promoted to any other type, while a string literal can be promoted to any simple data type.
+```sql
+> SET spark.sql.ansi.enabled=true;
+-- implicitly cast Int to String type
+> SELECT concat('total number: ', 1);
+total number: 1
+-- implicitly cast Timestamp to Date type
+> SELECT datediff(now(), current_date);
+0
+
+-- special rule: implicitly cast String literal to Double type
+> SELECT ceil('0.1');
+1
+-- special rule: implicitly cast NULL to Date type
+> SELECT year(null);
+NULL
+
+> CREATE TABLE t(s string);
+-- Can't store a String column as a numeric type.
+> SELECT ceil(s) FROM t;
+Error in query: cannot resolve 'CEIL(spark_catalog.default.t.s)' due to data type mismatch
+-- Can't store a String column as a Date type.
+> SELECT year(s) FROM t;
+Error in query: cannot resolve 'year(spark_catalog.default.t.s)' due to data type mismatch
+```
+
+#### Functions with different behaviors
 The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
   - `size`: This function returns null for null input.
   - `element_at`:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index e8bf2aeac136b..debc13b953ee5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -159,6 +159,10 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       // If the expected type equals the input type, no need to cast.
       case _ if expectedType.acceptsType(inType) => Some(inType)
 
+      // If input is a numeric type but not decimal, and we expect a decimal type,
+      // cast the input to decimal.
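+      // For example (illustrative): an IntegerType input checked against the abstract
+      // DecimalType resolves to DecimalType(10, 0), since DecimalType.forType(IntegerType)
+      // is DecimalType.IntDecimal.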
+ case (n: NumericType, DecimalType) => Some(DecimalType.forType(n)) + // Cast null type (usually from null literals) into target types // By default, the result type is `target.defaultConcreteType`. When the target type is // `TypeCollection`, there is another branch to find the "closet convertible data type" below. @@ -178,79 +182,17 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, DecimalType) if isInputFoldable => Some(DecimalType.SYSTEM_DEFAULT) - // If input is a numeric type but not decimal, and we expect a decimal type, - // cast the input to decimal. - case (d: NumericType, DecimalType) => Some(DecimalType.forType(d)) - - case (n1: NumericType, n2: NumericType) => - val widerType = findWiderTypeForTwo(n1, n2) - widerType match { - // if the expected type is Float type, we should still return Float type. - case Some(DoubleType) if n1 != DoubleType && n2 == FloatType => Some(FloatType) - - case Some(dt) if dt == n2 => Some(dt) - - case _ => None + case (_, target: DataType) => + if (Cast.canANSIStoreAssign(inType, target)) { + Some(target) + } else { + None } - case (DateType, TimestampType) => Some(TimestampType) - case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) - // When we reach here, input type is not acceptable for any types in this type collection, - // first try to find the all the expected types we can implicitly cast: - // 1. if there is no convertible data types, return None; - // 2. if there is only one convertible data type, cast input as it; - // 3. otherwise if there are multiple convertible data types, find the closet convertible - // data type among them. If there is no such a data type, return None. + // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => - // Since Spark contains special objects like `NumericType` and `DecimalType`, which accepts - // multiple types and they are `AbstractDataType` instead of `DataType`, here we use the - // conversion result their representation. - val convertibleTypes = types.flatMap(implicitCast(inType, _, isInputFoldable)) - if (convertibleTypes.isEmpty) { - None - } else { - // find the closet convertible data type, which can be implicit cast to all other - // convertible types. - val closestConvertibleType = convertibleTypes.find { dt => - convertibleTypes.forall { target => - implicitCast(dt, target, isInputFoldable = false).isDefined - } - } - // If the closet convertible type is Float type and the convertible types contains Double - // type, simply return Double type as the closet convertible type to avoid potential - // precision loss on converting the Integral type as Float type. - if (closestConvertibleType.contains(FloatType) && convertibleTypes.contains(DoubleType)) { - Some(DoubleType) - } else { - closestConvertibleType - } - } - - // Implicit cast between array types. - // - // Compare the nullabilities of the from type and the to type, check whether the cast of - // the nullability is resolvable by the following rules: - // 1. If the nullability of the to type is true, the cast is always allowed; - // 2. If the nullabilities of both the from type and the to type are false, the cast is - // allowed. - // 3. Otherwise, the cast is not allowed - case (ArrayType(fromType, containsNullFrom), ArrayType(toType: DataType, containsNullTo)) - if Cast.resolvableNullability(containsNullFrom, containsNullTo) => - implicitCast(fromType, toType, isInputFoldable).map(ArrayType(_, containsNullTo)) - - // Implicit cast between Map types. 
- // Follows the same semantics of implicit casting between two array types. - // Refer to documentation above. - case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) - if Cast.resolvableNullability(fn, tn) => - val newKeyType = implicitCast(fromKeyType, toKeyType, isInputFoldable) - val newValueType = implicitCast(fromValueType, toValueType, isInputFoldable) - if (newKeyType.isDefined && newValueType.isDefined) { - Some(MapType(newKeyType.get, newValueType.get, tn)) - } else { - None - } + types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption case _ => None } @@ -348,6 +290,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => s.copy(left = Cast(s.left, s.right.dataType)) case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 90cbe565fe6f3..506667461ec09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1157,9 +1157,9 @@ object TypeCoercion extends TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case d @ DateAdd(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateSub(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 97925454888da..5ec303d97fbd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -162,12 +162,13 @@ object RuleIdCollection { // In the production code path, the following rules are run in CombinedTypeCoercionRule, and // hence we only need to add them for unit testing. 
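    // For example, AnsiTypeCoercion$DateTimeOperations (registered below) rewrites
    // DateAdd(timestampCol, 1) into DateAdd(Cast(timestampCol, DateType), 1); the unit tests
    // exercise that rule on its own, so it needs its own rule id here.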
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$PromoteStringLiterals" :: + "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$GetDateFieldOperations" :: "org.apache.spark.sql.catalyst.analysis.DecimalPrecision" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality" :: + "org.apache.spark.sql.catalyst.analysis.TypeCoercion$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CaseWhenCoercion" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$ConcatCoercion" :: - "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$Division" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$EltCoercion" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$FunctionArgumentConversion" :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index ab8d9d9806921..809cbb2cebdbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -19,42 +19,34 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils -class AnsiTypeCoercionSuite extends AnalysisTest { +class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { import TypeCoercionSuite._ - // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise, - // RuleIdCollection doesn't add them because they are called in a train inside - // CombinedTypeCoercionRule. - assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") - // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. 
// +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | - // | ShortType | X | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | - // | IntegerType | X | X | IntegerType | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | - // | LongType | X | X | X | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | - // | FloatType | X | X | X | X | FloatType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | X | - // | DoubleType | X | X | X | X | X | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | X | - // | Dec(10, 2) | X | X | X | X | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | X | - // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | X | X | X | X | X | X | X | X | X | X | X | - // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | X | X | X | X | X | X | X | X | X | X | X | - // | StringType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | - // | DateType | X | X | X | X | X | X | X | X | X | X | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | TimestampType | X | X | X | X | X | X | X | X | X | X | X | TimestampType | X | X | X | X | X | X | X | X | + // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | + // | ShortType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | + // | IntegerType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | + // | LongType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | + // | DoubleType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(30, 15) | 
DoubleType | IntegerType | + // | FloatType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | IntegerType | + // | Dec(10, 2) | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | IntegerType | + // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | StringType | X | X | X | X | X | X | X | X | X | X | + // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | StringType | X | X | X | X | X | X | X | X | X | X | + // | StringType | X | X | X | X | X | X | X | X | X | StringType | X | X | X | X | X | X | X | X | X | X | + // | DateType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | + // | TimestampType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | // | MapType | X | X | X | X | X | X | X | X | X | X | X | X | X | MapType* | X | X | X | X | X | X | // | StructType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | @@ -65,30 +57,10 @@ class AnsiTypeCoercionSuite extends AnalysisTest { // Note: ArrayType* is castable when the element type is castable according to the table. // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit + override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = + AnsiTypeCoercion.implicitCast(e, expectedType) - private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { - // Check default value - val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) - assert(DataType.equalsIgnoreCompatibleNullability( - castDefault.map(_.dataType).getOrElse(null), expected), - s"Failed to cast $from to $to") - - // Check null value - val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) - assert(DataType.equalsIgnoreCaseAndNullability( - castNull.map(_.dataType).getOrElse(null), expected), - s"Failed to cast $from to $to") - } - - private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { - // Check default value - val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) - assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") - - // Check null value - val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) - assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") - } + override def dateTimeOperationsRule: TypeCoercionRule = AnsiTypeCoercion.DateTimeOperations private def shouldCastStringLiteral(to: AbstractDataType, expected: DataType): Unit = { val input = Literal("123") @@ -110,35 +82,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { assert(castResult.isEmpty, s"Should not be able to cast non-foldable String input to $to") } - private def default(dataType: DataType): Expression = dataType match { - case ArrayType(internalType: DataType, _) => - CreateArray(Seq(Literal.default(internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => - CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) - case _ => Literal.default(dataType) - 
} - - private def createNull(dataType: DataType): Expression = dataType match { - case ArrayType(internalType: DataType, _) => - CreateArray(Seq(Literal.create(null, internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => - CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) - case _ => Literal.create(null, dataType) - } - - // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, - // but cannot be cast to the other types in `allTypes`. - private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { - val nonCastableTypes = allTypes.filterNot(castableTypes.contains) - - castableTypes.foreach { tpe => - shouldCast(checkedType, tpe, tpe) - } - nonCastableTypes.foreach { tpe => - shouldNotCast(checkedType, tpe) - } - } - private def checkWidenType( widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, @@ -156,81 +99,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { } } - test("implicit type cast - ByteType") { - val checkedType = ByteType - checkTypeCasting(checkedType, castableTypes = numericTypes) - shouldCast(checkedType, DecimalType, DecimalType.ByteDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - ShortType") { - val checkedType = ShortType - checkTypeCasting(checkedType, castableTypes = numericTypes.filterNot(_ == ByteType)) - shouldCast(checkedType, DecimalType, DecimalType.ShortDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - IntegerType") { - val checkedType = IntegerType - checkTypeCasting(checkedType, castableTypes = - Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) - shouldCast(IntegerType, DecimalType, DecimalType.IntDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - LongType") { - val checkedType = LongType - checkTypeCasting(checkedType, castableTypes = - Seq(LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) - shouldCast(checkedType, DecimalType, DecimalType.LongDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - FloatType") { - val checkedType = FloatType - checkTypeCasting(checkedType, castableTypes = Seq(FloatType, DoubleType)) - shouldCast(checkedType, DecimalType, DecimalType.FloatDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - DoubleType") { - val checkedType = DoubleType - checkTypeCasting(checkedType, castableTypes = Seq(DoubleType)) - shouldCast(checkedType, DecimalType, DecimalType.DoubleDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - DecimalType(10, 2)") { - val checkedType = DecimalType(10, 2) - checkTypeCasting(checkedType, castableTypes = fractionalTypes) - shouldCast(checkedType, DecimalType, checkedType) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - BinaryType") { - val checkedType = BinaryType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, 
NumericType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - BooleanType") { - val checkedType = BooleanType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - shouldNotCast(checkedType, StringType) - } - test("implicit type cast - unfoldable StringType") { val nonCastableTypes = allTypes.filterNot(_ == StringType) nonCastableTypes.foreach { dt => @@ -251,23 +119,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { shouldCastStringLiteral(NumericType, DoubleType) } - test("implicit type cast - DateType") { - val checkedType = DateType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType, TimestampType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - shouldNotCast(checkedType, StringType) - } - - test("implicit type cast - TimestampType") { - val checkedType = TimestampType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast - unfoldable ArrayType(StringType)") { val input = AttributeReference("a", ArrayType(StringType))() val nonCastableTypes = allTypes.filterNot(_ == StringType) @@ -278,55 +129,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { assert(AnsiTypeCoercion.implicitCast(input, NumericType).isEmpty) } - test("implicit type cast - foldable arrayType(StringType)") { - val input = Literal(Array("1")) - assert(AnsiTypeCoercion.implicitCast(input, ArrayType(StringType)) == Some(input)) - (numericTypes ++ datetimeTypes ++ Seq(BinaryType)).foreach { dt => - assert(AnsiTypeCoercion.implicitCast(input, ArrayType(dt)) == - Some(Cast(input, ArrayType(dt)))) - } - } - - test("implicit type cast between two Map types") { - val sourceType = MapType(IntegerType, IntegerType, true) - val castableTypes = - Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT) - val targetTypes = castableTypes.map { t => - MapType(t, sourceType.valueType, valueContainsNull = true) - } - val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t => - MapType(t, sourceType.valueType, valueContainsNull = true) - } - - // Tests that its possible to setup implicit casts between two map types when - // source map's key type is integer and the target map's key type are either Byte, Short, - // Long, Double, Float, Decimal(38, 18) or String. - targetTypes.foreach { targetType => - shouldCast(sourceType, targetType, targetType) - } - - // Tests that its not possible to setup implicit casts between two map types when - // source map's key type is integer and the target map's key type are either Binary, - // Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType - nonCastableTargetTypes.foreach { targetType => - shouldNotCast(sourceType, targetType) - } - - // Tests that its not possible to cast from nullable map type to not nullable map type. 
- val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t => - MapType(t, sourceType.valueType, valueContainsNull = false) - } - val sourceMapExprWithValueNull = - CreateMap(Seq(Literal.default(sourceType.keyType), - Literal.create(null, sourceType.valueType))) - targetNotNullableTypes.foreach { targetType => - val castDefault = - AnsiTypeCoercion.implicitCast(sourceMapExprWithValueNull, targetType) - assert(castDefault.isEmpty, - s"Should not be able to cast $sourceType to $targetType, but got $castDefault") - } - } - test("implicit type cast - StructType().add(\"a1\", StringType)") { val checkedType = new StructType().add("a1", StringType) checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) @@ -345,64 +147,12 @@ class AnsiTypeCoercionSuite extends AnalysisTest { test("implicit type cast - CalendarIntervalType") { val checkedType = CalendarIntervalType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType)) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) } - test("eligible implicit type cast - TypeCollection") { - shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) - shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) - shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) - - shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType) - shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) - shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) - shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) - - shouldCast(DecimalType.SYSTEM_DEFAULT, - TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) - shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) - shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) - - shouldCast( - ArrayType(StringType, false), - TypeCollection(ArrayType(StringType), StringType), - ArrayType(StringType, false)) - - shouldCast( - ArrayType(StringType, true), - TypeCollection(ArrayType(StringType), StringType), - ArrayType(StringType, true)) - - // When there are multiple convertible types in the `TypeCollection`, use the closest - // convertible data type among convertible types. - shouldCast(IntegerType, TypeCollection(BinaryType, FloatType, LongType), LongType) - shouldCast(ShortType, TypeCollection(BinaryType, LongType, IntegerType), IntegerType) - shouldCast(ShortType, TypeCollection(DateType, LongType, IntegerType, DoubleType), IntegerType) - // If the result is Float type and Double type is also among the convertible target types, - // use Double Type instead of Float type. 
- shouldCast(LongType, TypeCollection(FloatType, DoubleType, StringType), DoubleType) - } - - test("ineligible implicit type cast - TypeCollection") { - shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType)) - shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType)) - shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) - shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType)) - shouldNotCastStringInput(TypeCollection(NumericType, BinaryType)) - // When there are multiple convertible types in the `TypeCollection` and there is no such - // a data type that can be implicit cast to all the other convertible types in the collection. - Seq(TypeCollection(NumericType, BinaryType), - TypeCollection(NumericType, DecimalType, BinaryType), - TypeCollection(IntegerType, LongType, BooleanType), - TypeCollection(DateType, TimestampType, BooleanType)).foreach { typeCollection => - shouldNotCastStringLiteral(typeCollection) - shouldNotCast(NullType, typeCollection) - } - } - test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = checkWidenType(AnsiTypeCoercion.findTightestCommonType, t1, t2, expected) @@ -606,25 +356,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { None) } - private def ruleTest(rule: Rule[LogicalPlan], - initial: Expression, transformed: Expression): Unit = { - ruleTest(Seq(rule), initial, transformed) - } - - private def ruleTest( - rules: Seq[Rule[LogicalPlan]], - initial: Expression, - transformed: Expression): Unit = { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - val analyzer = new RuleExecutor[LogicalPlan] { - override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) - } - - comparePlans( - analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - test("cast NullType for expressions that implement ExpectsInputTypes") { ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), @@ -1000,90 +731,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { Literal.create(null, IntegerType), Literal.create(null, StringType)))) } - test("type coercion for Concat") { - val rule = AnsiTypeCoercion.ConcatCoercion - - ruleTest(rule, - Concat(Seq(Literal("ab"), Literal("cde"))), - Concat(Seq(Literal("ab"), Literal("cde")))) - ruleTest(rule, - Concat(Seq(Literal(null), Literal("abc"))), - Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Concat(Seq(Literal(1), Literal("234"))), - Concat(Seq(Literal(1), Literal("234")))) - ruleTest(rule, - Concat(Seq(Literal("1"), Literal("234".getBytes()))), - Concat(Seq(Literal("1"), Literal("234".getBytes())))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), - Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1)))) - ruleTest(rule, - Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), - Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort)))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(0.1))), - Concat(Seq(Literal(1L), Literal(0.1)))) - ruleTest(rule, - Concat(Seq(Literal(Decimal(10)))), - Concat(Seq(Literal(Decimal(10))))) - ruleTest(rule, - Concat(Seq(Literal(BigDecimal.valueOf(10)))), - Concat(Seq(Literal(BigDecimal.valueOf(10))))) - ruleTest(rule, - Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), - Concat(Seq(Literal(java.math.BigDecimal.valueOf(10))))) - ruleTest(rule, - 
Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) - - ruleTest(rule, - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) - } - - test("type coercion for Elt") { - val rule = AnsiTypeCoercion.EltCoercion - - ruleTest(rule, - Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), - Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), - Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(null), Literal("abc"))), - Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(1), Literal("234"))), - Elt(Seq(Literal(2), Literal(1), Literal("234")))) - ruleTest(rule, - Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), - Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1)))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), - Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), - Elt(Seq(Literal(1), Literal(1L), Literal(0.1)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(Decimal(10)))), - Elt(Seq(Literal(1), Literal(Decimal(10))))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10))))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10))))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) - - ruleTest(rule, - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) - } - private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2dc669bbb9e27..8de84b3ae2fdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import java.time.{Duration, Period} +import java.time.{Duration, LocalDateTime, Period} import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ @@ -31,7 +31,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TypeCoercionSuite extends AnalysisTest { +abstract class TypeCoercionSuiteBase extends AnalysisTest { import TypeCoercionSuite._ // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise, @@ -39,59 +39,35 @@ class TypeCoercionSuite extends AnalysisTest { // CombinedTypeCoercionRule. 
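  // (Utils.isTesting is true when the SPARK_TESTING environment variable or the
  // spark.testing system property is set.)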
assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") - // scalastyle:off line.size.limit - // The following table shows all implicit data type conversions that are not visible to the user. - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | - // | ShortType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | - // | IntegerType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | - // | LongType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | - // | DoubleType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | IntegerType | - // | FloatType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | IntegerType | - // | Dec(10, 2) | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | IntegerType | - // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | StringType | X | X | X | X | X | X | X | X | X | X | - // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | StringType | X | X | X | X | X | X | X | X | X | X | - // | StringType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | X | StringType | DateType | TimestampType | X | X | X | X | X | DecimalType(38, 18) | DoubleType | X | - // | DateType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | TimestampType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | - // | MapType | X | X | X | X | X | X | X | X | X | X | X | X | X | MapType* | X | X | X | X | X | X | - // | StructType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | - // | 
NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | - // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: StructType* is castable when all the internal child types are castable according to the table. - // Note: ArrayType* is castable when the element type is castable according to the table. - // Note: MapType* is castable when both the key type and the value type are castable according to the table. - // scalastyle:on line.size.limit + protected def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] + + protected def dateTimeOperationsRule: TypeCoercionRule - private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { + protected def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { // Check default value - val castDefault = TypeCoercion.implicitCast(default(from), to) + val castDefault = implicitCast(default(from), to) assert(DataType.equalsIgnoreCompatibleNullability( castDefault.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") // Check null value - val castNull = TypeCoercion.implicitCast(createNull(from), to) + val castNull = implicitCast(createNull(from), to) assert(DataType.equalsIgnoreCaseAndNullability( castNull.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") } - private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + protected def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { // Check default value - val castDefault = TypeCoercion.implicitCast(default(from), to) + val castDefault = implicitCast(default(from), to) assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") // Check null value - val castNull = TypeCoercion.implicitCast(createNull(from), to) + val castNull = implicitCast(createNull(from), to) assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") } - private def default(dataType: DataType): Expression = dataType match { + protected def default(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.default(internalType))) case MapType(keyDataType: DataType, valueDataType: DataType, _) => @@ -99,7 +75,7 @@ class TypeCoercionSuite extends AnalysisTest { case _ => Literal.default(dataType) } - private def createNull(dataType: DataType): Expression = dataType match { + protected def createNull(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.create(null, internalType))) case MapType(keyDataType: DataType, valueDataType: DataType, _) => @@ -109,7 +85,7 @@ class TypeCoercionSuite extends AnalysisTest { // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, // but cannot be cast to the other types in `allTypes`. 
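  // For example, checkTypeCasting(ByteType, castableTypes = numericTypes) asserts that
  // ByteType casts implicitly to every numeric type and to nothing else in `allTypes`.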
- private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { + protected def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { val nonCastableTypes = allTypes.filterNot(castableTypes.contains) castableTypes.foreach { tpe => @@ -120,21 +96,23 @@ class TypeCoercionSuite extends AnalysisTest { } } - private def checkWidenType( - widenFunc: (DataType, DataType) => Option[DataType], - t1: DataType, - t2: DataType, - expected: Option[DataType], - isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2) - assert(found == expected, - s"Expected $expected as wider common type for $t1 and $t2, found $found") - // Test both directions to make sure the widening is symmetric. - if (isSymmetric) { - found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + protected def ruleTest(rule: Rule[LogicalPlan], + initial: Expression, transformed: Expression): Unit = { + ruleTest(Seq(rule), initial, transformed) + } + + protected def ruleTest( + rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) } + + comparePlans( + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("implicit type cast - ByteType") { @@ -209,16 +187,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - StringType") { - val checkedType = StringType - val nonCastableTypes = - complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) - checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains)) - shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) - shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast - DateType") { val checkedType = DateType checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType, TimestampType)) @@ -235,20 +203,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - ArrayType(StringType)") { - val checkedType = ArrayType(StringType) - val nonCastableTypes = - complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) - checkTypeCasting(checkedType, - castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_))) - nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _)) - shouldNotCast(ArrayType(DoubleType, containsNull = false), - ArrayType(LongType, containsNull = false)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast between two Map types") { val sourceType = MapType(IntegerType, IntegerType, true) val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _)) @@ -288,30 +242,6 @@ class TypeCoercionSuite extends AnalysisTest { } } - test("implicit type cast - StructType().add(\"a1\", StringType)") { - val checkedType = new StructType().add("a1", StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - 
shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - NullType") { - val checkedType = NullType - checkTypeCasting(checkedType, castableTypes = allTypes) - shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) - shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) - shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType) - } - - test("implicit type cast - CalendarIntervalType") { - val checkedType = CalendarIntervalType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("eligible implicit type cast - TypeCollection") { shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) @@ -333,8 +263,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) - shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) - shouldCast( ArrayType(StringType, false), TypeCollection(ArrayType(StringType), StringType), @@ -350,6 +278,249 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) } + test("type coercion for Concat") { + val rule = TypeCoercion.ConcatCoercion + + ruleTest(rule, + Concat(Seq(Literal("ab"), Literal("cde"))), + Concat(Seq(Literal("ab"), Literal("cde")))) + ruleTest(rule, + Concat(Seq(Literal(null), Literal("abc"))), + Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Concat(Seq(Literal(1), Literal("234"))), + Concat(Seq(Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Concat(Seq(Literal("1"), Literal("234".getBytes()))), + Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), + Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(Decimal(10)))), + Concat(Seq(Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "true") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + 
Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) + } + } + + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + + test("Datetime operations") { + val rule = dateTimeOperationsRule + val dateLiteral = Literal(java.sql.Date.valueOf("2021-01-01")) + val timestampLiteral = Literal(Timestamp.valueOf("2021-01-01 00:00:00")) + val timestampNTZLiteral = Literal(LocalDateTime.parse("2021-01-01T00:00:00")) + val intLiteral = Literal(3) + Seq(timestampLiteral, timestampNTZLiteral).foreach { tsLiteral => + ruleTest(rule, + DateAdd(tsLiteral, intLiteral), + DateAdd(Cast(tsLiteral, DateType), intLiteral)) + ruleTest(rule, + DateSub(tsLiteral, intLiteral), + DateSub(Cast(tsLiteral, DateType), intLiteral)) + ruleTest(rule, + SubtractTimestamps(tsLiteral, dateLiteral), + SubtractTimestamps(tsLiteral, Cast(dateLiteral, tsLiteral.dataType))) + ruleTest(rule, + SubtractTimestamps(dateLiteral, tsLiteral), + SubtractTimestamps(Cast(dateLiteral, tsLiteral.dataType), tsLiteral)) + } + + ruleTest(rule, + SubtractTimestamps(timestampLiteral, timestampNTZLiteral), + SubtractTimestamps(Cast(timestampLiteral, TimestampNTZType), timestampNTZLiteral)) + ruleTest(rule, + 
SubtractTimestamps(timestampNTZLiteral, timestampLiteral),
+      SubtractTimestamps(timestampNTZLiteral, Cast(timestampLiteral, TimestampNTZType)))
+  }
+
+}
+
+class TypeCoercionSuite extends TypeCoercionSuiteBase {
+  import TypeCoercionSuite._
+
+  // scalastyle:off line.size.limit
+  // The following table shows all implicit data type conversions that are not visible to the user.
+  // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
+  // | Source Type\CAST TO  | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType  | MapType  | StructType  | NullType | CalendarIntervalType | DecimalType         | NumericType | IntegralType |
+  // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
+  // | ByteType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(3, 0)   | ByteType    | ByteType     |
+  // | ShortType            | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(5, 0)   | ShortType   | ShortType    |
+  // | IntegerType          | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(10, 0)  | IntegerType | IntegerType  |
+  // | LongType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(20, 0)  | LongType    | LongType     |
+  // | DoubleType           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(30, 15) | DoubleType  | IntegerType  |
+  // | FloatType            | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(14, 7)  | FloatType   | IntegerType  |
+  // | Dec(10, 2)           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(10, 2)  | Dec(10, 2)  | IntegerType  |
+  // | BinaryType           | X        | X         | X           | X        | X          | X         | X          | BinaryType | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | X                   | X           | X            |
+  // | BooleanType          | X        | X         | X           | X        | X          | X         | X          | X          | BooleanType | StringType | X        | X             | X          | X        | X           | X        | X                    | X                   | X           | X            |
+  // | StringType           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | DecimalType(38, 18) | DoubleType  | X            |
+  // | DateType             | X        | X         | X           | X        | X          | X         | X          | X          | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | X                   | X           | X            |
+  // | TimestampType        | X        | X         | X           | X        | X          | X         | X          | X          | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | X                   | X           | X            |
+  // | ArrayType            | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | ArrayType* | X        | X           | X        | X                    | X                   | X           | X            |
+  // | MapType              | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | MapType* | X           | X        | X                    | X                   | X           | X            |
+  // | StructType           | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | X        | StructType* | X        | X                    | X                   | X           | X            |
+  // | NullType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType  | MapType  | StructType  | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType  | IntegerType  |
+  // | CalendarIntervalType | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | X        | X           | X        | CalendarIntervalType | X                   | X           | X            |
+  // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
+  // Note: StructType* is castable when all the internal child types are castable according to the table.
+  // Note: ArrayType* is castable when the element type is castable according to the table.
+  // Note: MapType* is castable when both the key type and the value type are castable according to the table.
+  // scalastyle:on line.size.limit
+  override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] =
+    TypeCoercion.implicitCast(e, expectedType)
+
+  override def dateTimeOperationsRule: TypeCoercionRule = TypeCoercion.DateTimeOperations
+
+  private def checkWidenType(
+      widenFunc: (DataType, DataType) => Option[DataType],
+      t1: DataType,
+      t2: DataType,
+      expected: Option[DataType],
+      isSymmetric: Boolean = true): Unit = {
+    var found = widenFunc(t1, t2)
+    assert(found == expected,
+      s"Expected $expected as wider common type for $t1 and $t2, found $found")
+    // Test both directions to make sure the widening is symmetric.
+    if (isSymmetric) {
+      found = widenFunc(t2, t1)
+      assert(found == expected,
+        s"Expected $expected as wider common type for $t2 and $t1, found $found")
+    }
+  }
+
+  test("implicit type cast - StringType") {
+    val checkedType = StringType
+    val nonCastableTypes =
+      complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
+    checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains))
+    shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT)
+    shouldCast(checkedType, NumericType, NumericType.defaultConcreteType)
+    shouldNotCast(checkedType, IntegralType)
+  }
+
+  test("implicit type cast - ArrayType(StringType)") {
+    val checkedType = ArrayType(StringType)
+    val nonCastableTypes =
+      complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
+    checkTypeCasting(checkedType,
+      castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_)))
+    nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _))
+    shouldNotCast(ArrayType(DoubleType, containsNull = false),
+      ArrayType(LongType, containsNull = false))
+    shouldNotCast(checkedType, DecimalType)
+    shouldNotCast(checkedType, NumericType)
+    shouldNotCast(checkedType, IntegralType)
+  }
+
+  test("implicit type cast - StructType().add(\"a1\", StringType)") {
+    val checkedType = new StructType().add("a1", StringType)
+    checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
+    shouldNotCast(checkedType, DecimalType)
+    shouldNotCast(checkedType, NumericType)
+    shouldNotCast(checkedType, IntegralType)
+  }
+
+  test("implicit type cast - NullType") {
+    val checkedType = NullType
+    checkTypeCasting(checkedType, castableTypes = allTypes)
+    shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT)
+    shouldCast(checkedType, NumericType, NumericType.defaultConcreteType)
+    shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType)
+  }
+
+  test("implicit type cast - CalendarIntervalType") {
+    val checkedType = CalendarIntervalType
+    checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
+    shouldNotCast(checkedType, DecimalType)
+    shouldNotCast(checkedType, NumericType)
+    shouldNotCast(checkedType, IntegralType)
+  }
+
+  test("eligible implicit type cast - TypeCollection II") {
+    shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType)
+  }
+
   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
       checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)
@@ -717,25 +888,6 @@ class TypeCoercionSuite extends AnalysisTest {
       Some(new StructType().add("a", StringType)))
   }
 
-  private def ruleTest(rule: Rule[LogicalPlan],
-    initial: Expression, transformed: Expression): Unit = {
-    ruleTest(Seq(rule), initial, transformed)
-  }
-
-  private def ruleTest(
-      rules: Seq[Rule[LogicalPlan]],
-      initial: Expression,
-      transformed: Expression): Unit = {
-    val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-    val analyzer = new RuleExecutor[LogicalPlan] {
-      override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
-    }
-
-    comparePlans(
-      analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
-      Project(Seq(Alias(transformed, "a")()), testRelation))
-  }
-
   test("cast NullType for expressions that implement ExpectsInputTypes") {
     ruleTest(TypeCoercion.ImplicitTypeCasts,
       AnyTypeUnaryExpression(Literal.create(null, NullType)),
@@ -1110,114 +1262,6 @@ class TypeCoercionSuite extends AnalysisTest {
         Literal.create(null, IntegerType), Literal.create(null, StringType))))
   }
 
-  test("type coercion for Concat") {
-    val rule = TypeCoercion.ConcatCoercion
-
-    ruleTest(rule,
-      Concat(Seq(Literal("ab"), Literal("cde"))),
-      Concat(Seq(Literal("ab"), Literal("cde"))))
-    ruleTest(rule,
-      Concat(Seq(Literal(null), Literal("abc"))),
-      Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
-    ruleTest(rule,
-      Concat(Seq(Literal(1), Literal("234"))),
-      Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
-    ruleTest(rule,
-      Concat(Seq(Literal("1"), Literal("234".getBytes()))),
-      Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
-      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
-        Cast(Literal(0.1), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
-      Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
-        Cast(Literal(3.toShort), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(1L), Literal(0.1))),
-      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(Decimal(10)))),
-      Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(BigDecimal.valueOf(10)))),
-      Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
-      Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
-    ruleTest(rule,
-      Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
-      Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
-        Cast(Literal(new Timestamp(0)), StringType))))
-
-    withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "true") {
-      ruleTest(rule,
-        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
-        Concat(Seq(Cast(Literal("123".getBytes), StringType),
-          Cast(Literal("456".getBytes), StringType))))
-    }
-
-    withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") {
-      ruleTest(rule,
-        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
-        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
-    }
-  }
-
-  test("type coercion for Elt") {
-    val rule = TypeCoercion.EltCoercion
-
-    ruleTest(rule,
-      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
-      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
-    ruleTest(rule,
-      Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
-      Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
-    ruleTest(rule,
-      Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
-      Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
-    ruleTest(rule,
-      Elt(Seq(Literal(2), Literal(1), Literal("234"))),
-      Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
-    ruleTest(rule,
-      Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
-      Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
-        Cast(Literal(0.1), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
-      Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
-        Cast(Literal(3.toShort), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
-      Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(1), Literal(Decimal(10)))),
-      Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
-      Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
-      Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
-    ruleTest(rule,
-      Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
-      Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
-        Cast(Literal(new Timestamp(0)), StringType))))
-
-    withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "true") {
-      ruleTest(rule,
-        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
-        Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
-          Cast(Literal("456".getBytes), StringType))))
-    }
-
-    withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "false") {
-      ruleTest(rule,
-        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
-        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
-    }
-  }
-
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out
index ac64bf6c05eab..b95c8dac9a82c 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out
@@ -219,10 +219,9 @@ struct
 -- !query
 select next_day(timestamp_ltz"2015-07-23 12:12:12", "Mon")
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'next_day(TIMESTAMP '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2015-07-23 12:12:12'' is of timestamp type.; line 1 pos 7
+2015-07-27


 -- !query
@@ -354,19 +353,17 @@ NULL
 -- !query
 select date_add(timestamp_ltz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_add(TIMESTAMP '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2011-11-11 12:12:12'' is of timestamp type.; line 1 pos 7
+2011-11-12


 -- !query
 select date_add(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-12


 -- !query
@@ -464,19 +461,17 @@ NULL
 -- !query
 select date_sub(timestamp_ltz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_sub(TIMESTAMP '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2011-11-11 12:12:12'' is of timestamp type.; line 1 pos 7
+2011-11-10


 -- !query
 select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-10


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index 12c98ff138da8..e9c323254b4a1 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -639,8 +639,8 @@ select make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789L)' due to data type mismatch: argument 7 requires decimal(18,6) type, however, '1234567890123456789L' is of bigint type.; line 1 pos 7
+org.apache.spark.SparkArithmeticException
+Decimal(expanded,1234567890123456789,20,0}) cannot be represented as Decimal(18, 6). If necessary set spark.sql.ansi.enabled to false to bypass this error.


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 879d89eb5074b..45d403859a2cf 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -71,10 +71,9 @@ ab abcd ab NULL
 -- !query
 select left(null, -2)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'substring(NULL, 1, -2)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of void type.; line 1 pos 7
+NULL


 -- !query
@@ -89,19 +88,17 @@ invalid input syntax for type numeric: a. To return NULL instead, use 'try_cast'
 -- !query
 select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'substring('abcd', (- CAST('2' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('2' AS DOUBLE))' is of double type.; line 1 pos 43
+cd abcd cd NULL


 -- !query
 select right(null, -2)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'substring(NULL, (- -2), 2147483647)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of void type.; line 1 pos 7
+NULL


 -- !query
@@ -109,8 +106,8 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a')
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 44
+java.lang.NumberFormatException
+invalid input syntax for type numeric: a. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error.

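The golden-file changes above all follow from the same store-assignment rule: a string literal is now implicitly castable to the declared INT parameter, so these queries pass analysis, and any malformed value surfaces as a runtime error instead. A minimal sketch of the two outcomes, assuming a spark-sql session with ANSI mode enabled:

```sql
-- assumes spark.sql.ansi.enabled=true, as in the ansi/*.sql.out files above
SET spark.sql.ansi.enabled=true;

-- the literal '2' is store-assignable to the INT length parameter,
-- so the query resolves and runs
SELECT right('abcd', '2');   -- cd

-- 'a' passes analysis for the same reason, but the runtime cast to a
-- number fails
SELECT right('abcd', 'a');   -- java.lang.NumberFormatException
```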
 -- !query
@@ -308,28 +305,25 @@ trim
 -- !query
 SELECT btrim(encode(" xyz ", 'utf-8'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'trim(encode(' xyz ', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode(' xyz ', 'utf-8')' is of binary type.; line 1 pos 7
+xyz


 -- !query
 SELECT btrim(encode('yxTomxx', 'utf-8'), encode('xyz', 'utf-8'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH encode('xyz', 'utf-8') FROM encode('yxTomxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('yxTomxx', 'utf-8')' is of binary type. argument 2 requires string type, however, 'encode('xyz', 'utf-8')' is of binary type.; line 1 pos 7
+Tom


 -- !query
 SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH encode('x', 'utf-8') FROM encode('xxxbarxxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('xxxbarxxx', 'utf-8')' is of binary type. argument 2 requires string type, however, 'encode('x', 'utf-8')' is of binary type.; line 1 pos 7
+bar


 -- !query
@@ -545,37 +539,33 @@ AABB
 -- !query
 SELECT lpad('abc', 5, x'57')
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'lpad('abc', 5, X'57')' due to data type mismatch: argument 3 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7
+WWabc


 -- !query
 SELECT lpad(x'57', 5, 'abc')
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'lpad(X'57', 5, 'abc')' due to data type mismatch: argument 1 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7
+abcaW


 -- !query
 SELECT rpad('abc', 5, x'57')
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'rpad('abc', 5, X'57')' due to data type mismatch: argument 3 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7
+abcWW


 -- !query
 SELECT rpad(x'57', 5, 'abc')
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'rpad(X'57', 5, 'abc')' due to data type mismatch: argument 1 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7
+Wabca


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/date.sql.out b/sql/core/src/test/resources/sql-tests/results/date.sql.out
index 2eacb6cdce66d..562028945103e 100644
--- a/sql/core/src/test/resources/sql-tests/results/date.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/date.sql.out
@@ -349,10 +349,9 @@ struct
 -- !query
 select date_add(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-12


 -- !query
@@ -458,10 +457,9 @@ struct
 -- !query
 select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-10


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
index 573ce3db9e5b2..74480ab6cc2b4 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out
@@ -349,10 +349,9 @@ struct
 -- !query
 select date_add(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-12


 -- !query
@@ -458,10 +457,9 @@ struct
 -- !query
 select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7
+2011-11-10


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
index 8a4ee142011ce..bc13bb893b118 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
@@ -4793,43 +4793,33 @@ Infinity
 -- !query
 select * from range(cast(0.0 as decimal(38, 18)), cast(4.0 as decimal(38, 18)))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-Table-valued function range with alternatives:
-  range(start: long, end: long, step: long, numSlices: integer)
-  range(start: long, end: long, step: long)
-  range(start: long, end: long)
-  range(end: long)
-cannot be applied to (decimal(38,18), decimal(38,18)): Incompatible input data type. Expected: long; Found: decimal(38,18); line 1 pos 14
+0
+1
+2
+3


 -- !query
 select * from range(cast(0.1 as decimal(38, 18)), cast(4.0 as decimal(38, 18)), cast(1.3 as decimal(38, 18)))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-Table-valued function range with alternatives:
-  range(start: long, end: long, step: long, numSlices: integer)
-  range(start: long, end: long, step: long)
-  range(start: long, end: long)
-  range(end: long)
-cannot be applied to (decimal(38,18), decimal(38,18), decimal(38,18)): Incompatible input data type. Expected: long; Found: decimal(38,18); line 1 pos 14
+0
+1
+2
+3


 -- !query
 select * from range(cast(4.0 as decimal(38, 18)), cast(-1.5 as decimal(38, 18)), cast(-2.2 as decimal(38, 18)))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-Table-valued function range with alternatives:
-  range(start: long, end: long, step: long, numSlices: integer)
-  range(start: long, end: long, step: long)
-  range(start: long, end: long)
-  range(end: long)
-cannot be applied to (decimal(38,18), decimal(38,18), decimal(38,18)): Incompatible input data type. Expected: long; Found: decimal(38,18); line 1 pos 14
+0
+2
+4


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out
index 253a5e49b81fa..28904629df373 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out
@@ -977,37 +977,33 @@ struct
 -- !query
 SELECT trim(binary('\\000') from binary('\\000Tom\\000'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000Tom\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000Tom\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7
+Tom


 -- !query
 SELECT btrim(binary('\\000trim\\000'), binary('\\000'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7
+trim


 -- !query
 SELECT btrim(binary(''), binary('\\000'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7
+


 -- !query
 SELECT btrim(binary('\\000trim\\000'), binary(''))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'TRIM(BOTH CAST('' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('' AS BINARY)' is of binary type.; line 1 pos 7
+\000trim\000


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out
index 25bcdb4b0cceb..75ca4f27ba49f 100755
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out
@@ -54,10 +54,9 @@ struct
 -- !query
 select length(42)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'length(42)' due to data type mismatch: argument 1 requires (string or binary) type, however, '42' is of int type.; line 1 pos 7
+2


 -- !query
@@ -65,8 +64,8 @@ select string('four: ') || 2+2
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat(CAST('four: ' AS STRING), 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7
+java.lang.NumberFormatException
+invalid input syntax for type numeric: four: 2. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error.

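The two `||` failures above are now runtime rather than analysis failures because `'four: ' || 2+2` parses as `(('four: ' || 2) + 2)`: the concat itself resolves under the new rule and yields the string `'four: 2'`, which then cannot be cast to a numeric operand of `+`. A short sketch of the distinction, assuming ANSI mode is enabled:

```sql
-- assumes spark.sql.ansi.enabled=true
-- parses as (('four: ' || 2) + 2); concat resolves, so the failure moves
-- to the runtime cast of the string 'four: 2' to a numeric type
SELECT 'four: ' || 2 + 2;    -- java.lang.NumberFormatException

-- parenthesizing keeps the whole expression in string territory
SELECT 'four: ' || (2 + 2);  -- four: 4
```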
 -- !query
@@ -74,17 +73,16 @@ select 'four: ' || 2+2
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat('four: ', 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7
+java.lang.NumberFormatException
+invalid input syntax for type numeric: four: 2. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error.


 -- !query
 select 3 || 4.0
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat(3, 4.0BD)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, decimal(2,1)]; line 1 pos 7
+34.0


 -- !query
@@ -101,10 +99,9 @@ one
 -- !query
 select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat(1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, int, int, string, boolean, boolean, date]; line 1 pos 7
+123hellotruefalse2010-03-09


 -- !query
@@ -118,37 +115,33 @@ one
 -- !query
 select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd'))
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws('#', 1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: argument 2 requires (array or string) type, however, '1' is of int type. argument 3 requires (array or string) type, however, '2' is of int type. argument 4 requires (array or string) type, however, '3' is of int type. argument 6 requires (array or string) type, however, 'true' is of boolean type. argument 7 requires (array or string) type, however, 'false' is of boolean type. argument 8 requires (array or string) type, however, 'to_date('20100309', 'yyyyMMdd')' is of date type.; line 1 pos 7
+1#x#x#hello#true#false#x-03-09


 -- !query
 select concat_ws(',',10,20,null,30)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7
+10,20,30


 -- !query
 select concat_ws('',10,20,null,30)
 -- !query schema
-struct<>
+struct
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7
+102030


 -- !query
 select concat_ws(NULL,10,20,null,30) is null
 -- !query schema
-struct<>
+struct<(concat_ws(NULL, 10, 20, NULL, 30) IS NULL):boolean>
 -- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7
+true


 -- !query
@@ -162,10 +155,19 @@ edcba
 -- !query
 select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i
 -- !query schema
-struct<>
--- !query output
-org.apache.spark.sql.AnalysisException
-cannot resolve 'substring('ahoj', 1, t.i)' due to data type mismatch: argument 3 requires int type, however, 't.i' is of bigint type.; line 1 pos 10
+struct
+-- !query output
+-5
+-4
+-3
+-2
+-1
+0
+1 a j
+2 ah oj
+3 aho hoj
+4 ahoj ahoj
+5 ahoj ahoj


 -- !query
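The final block shows the widest consequence of the new rule: a BIGINT column such as `t.i` can now be passed where an INT parameter is declared, with any overflow deferred to a runtime check instead of being rejected at analysis time. Both queries below are taken verbatim from the golden files in this patch; only the session setting is assumed:

```sql
-- assumes spark.sql.ansi.enabled=true
-- BIGINT i is store-assignable to the declared INT length parameter, so
-- this now resolves and returns one row per value of i
select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i;

-- values that cannot fit the declared type still fail, but at runtime
-- (SparkArithmeticException) rather than at analysis time
select make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789);
```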