diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index dfaac92e04a2..2c00957bd6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -101,13 +101,11 @@ object TypeCoercion { case _ => None } - /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ - def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { - findTightestCommonType(left, right).orElse((left, right) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) - case _ => None - }) + /** Promotes all the way to StringType. */ + private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None } /** @@ -117,21 +115,17 @@ object TypeCoercion { * loss of precision when widening decimal and double, and promotion to string. */ private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - (t1, t2) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(t1, t2) - } + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse(stringPromotion(t1, t2)) + .orElse((t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case _ => None + }) } - private def findWiderCommonType(types: Seq[DataType]) = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findWiderTypeForTwo(d, c) case None => None @@ -139,27 +133,49 @@ object TypeCoercion { } /** - * Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to + * Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to * string. If the wider decimal type exceeds system limitation, this rule will truncate * the decimal type before return it. */ - def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findTightestCommonType(d, c).orElse((d, c) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) + private[analysis] def findWiderTypeWithoutStringPromotionForTwo( + t1: DataType, + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) + .orElse(findWiderTypeForDecimal(t1, t2)) + .orElse((t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findWiderTypeWithoutStringPromotionForTwo(et1, et2) + .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) + } + + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } + /** + * Finds a wider type when one or both types are decimals. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type. If a decimal and other fractional + * types are compared, returns a double type. + */ + private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = { + (dt1, dt2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None + } + } + private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index ceb5b53e0847..3e0c357b6de4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable + // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: ArrayType* is castable when the element type is castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest { } } + private def checkWidenType( + widenFunc: (DataType, DataType) => Option[DataType], + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + var found = widenFunc(t1, t2) + assert(found == expected, + s"Expected $expected as wider common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } + test("implicit type cast - ByteType") { val checkedType = ByteType checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType)) @@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest { } test("tightest common bound for types") { - def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = TypeCoercion.findTightestCommonType(t1, t2) - assert(found == tightestCommon, - s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") - // Test both directions to make sure the widening is symmetric. - found = TypeCoercion.findTightestCommonType(t2, t1) - assert(found == tightestCommon, - s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") - } + def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = + checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected) // Null widenTest(NullType, NullType, Some(NullType)) @@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest { widenTest(DecimalType(2, 1), DoubleType, None) widenTest(DecimalType(2, 1), IntegerType, None) widenTest(DoubleType, DecimalType(2, 1), None) - widenTest(IntegerType, DecimalType(2, 1), None) // StringType widenTest(NullType, StringType, Some(StringType)) @@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } + test("wider common type for decimal and array") { + def widenTestWithStringPromotion( + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + } + + def widenTestWithoutStringPromotion( + t1: DataType, + t2: DataType, + expected: Option[DataType]): Unit = { + checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + } + + // Decimal + widenTestWithStringPromotion( + DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2))) + widenTestWithStringPromotion( + DecimalType(2, 1), DoubleType, Some(DoubleType)) + widenTestWithStringPromotion( + DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1))) + widenTestWithStringPromotion( + DecimalType(2, 1), LongType, Some(DecimalType(21, 1))) + + // ArrayType + widenTestWithStringPromotion( + ArrayType(ShortType, containsNull = true), + ArrayType(DoubleType, containsNull = false), + Some(ArrayType(DoubleType, containsNull = true))) + widenTestWithStringPromotion( + ArrayType(TimestampType, containsNull = false), + ArrayType(StringType, containsNull = true), + Some(ArrayType(StringType, containsNull = true))) + widenTestWithStringPromotion( + ArrayType(ArrayType(IntegerType), containsNull = false), + ArrayType(ArrayType(LongType), containsNull = false), + Some(ArrayType(ArrayType(LongType), containsNull = false))) + + // Without string promotion + widenTestWithoutStringPromotion(IntegerType, StringType, None) + widenTestWithoutStringPromotion(StringType, TimestampType, None) + widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) + widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + + // String promotion + widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) + widenTestWithStringPromotion(StringType, TimestampType, Some(StringType)) + widenTestWithStringPromotion( + ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + } + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { ruleTest(Seq(rule), initial, transformed) }