Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ object TypeCoercion {
}
}

def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = {
if (types.isEmpty) {
None
} else {
types.tail.foldLeft[Option[DataType]](Some(types.head)) {
case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2)
case _ => None
}
}
}

/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
Expand Down Expand Up @@ -259,8 +270,25 @@ object TypeCoercion {
}
}

private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1
/**
* Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull.
*/
def haveSameType(types: Seq[DataType]): Boolean = {
if (types.size <= 1) {
true
} else {
val head = types.head
types.tail.forall(_.sameType(head))
}
}

private def castIfNotSameType(expr: Expression, dt: DataType): Expression = {
if (!expr.dataType.sameType(dt)) {
Cast(expr, dt)
} else {
expr
}
}

/**
* Widens numeric types and converts strings to numbers when appropriate.
Expand Down Expand Up @@ -525,23 +553,24 @@ object TypeCoercion {
* This ensure that the types for various functions are as expected.
*/
object FunctionArgumentConversion extends TypeCoercionRule {

override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case a @ CreateArray(children) if !haveSameType(children) =>
case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
case None => a
}

case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
!haveSameType(c.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
case None => c
}

Expand All @@ -553,41 +582,34 @@ object TypeCoercion {
case None => aj
}

case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) =>
case s @ Sequence(_, _, _, timeZoneId)
if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
val types = s.coercibleChildren.map(_.dataType)
findWiderCommonType(types) match {
case Some(widerDataType) => s.castChildrenTo(widerDataType)
case None => s
}

case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
!haveSameType(m.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
case None => m
}

case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
m.keys
} else {
val types = m.keys.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
(!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
val keyTypes = m.keys.map(_.dataType)
val newKeys = findWiderCommonType(keyTypes) match {
case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
case None => m.keys
}

val newValues = if (haveSameType(m.values)) {
m.values
} else {
val types = m.values.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
val valueTypes = m.values.map(_.dataType)
val newValues = findWiderCommonType(valueTypes) match {
case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
case None => m.values
}

CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
Expand All @@ -610,27 +632,27 @@ object TypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(es) =>
case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
val types = es.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType)))
case None => c
}

// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
// we need to truncate, but we should not promote one side to string if the other side is
// string.g
case g @ Greatest(children) if !haveSameType(children) =>
case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
case None => g
}

case l @ Least(children) if !haveSameType(children) =>
case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
case None => l
}

Expand Down Expand Up @@ -672,27 +694,14 @@ object TypeCoercion {
object CaseWhenCoercion extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual =>
case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) =>
val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
maybeCommonType.map { commonType =>
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
if (value.dataType.sameType(commonType)) {
(condition, value)
} else {
changed = true
(condition, Cast(value, commonType))
}
}
val newElseValue = c.elseValue.map { value =>
if (value.dataType.sameType(commonType)) {
value
} else {
changed = true
Cast(value, commonType)
}
(condition, castIfNotSameType(value, commonType))
}
if (changed) CaseWhen(newBranches, newElseValue) else c
val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType))
CaseWhen(newBranches, newElseValue)
}.getOrElse(c)
}
}
Expand All @@ -705,10 +714,10 @@ object TypeCoercion {
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) =>
findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
val newLeft = castIfNotSameType(left, widestType)
val newRight = castIfNotSameType(right, widestType)
If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
case If(Literal(null, NullType), left, right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,22 +709,12 @@ trait ComplexTypeMergingExpression extends Expression {
@transient
lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)

/**
* A method determining whether the input types are equal ignoring nullable, containsNull and
* valueContainsNull flags and thus convenient for resolution of the final data type.
*/
def areInputTypesForMergingEqual: Boolean = {
inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
case Seq(dt1, dt2) => dt1.sameType(dt2)
}
}

override def dataType: DataType = {
require(
inputTypesForMerging.nonEmpty,
"The collection of input data types must not be empty.")
require(
areInputTypesForMergingEqual,
TypeCoercion.haveSameType(inputTypesForMerging),
"All input types must be the same except nullable, containsNull, valueContainsNull flags.")
inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
Expand Down Expand Up @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
> SELECT _FUNC_(10, 9, 2, 4, 3);
2
""")
case class Least(children: Seq[Expression]) extends Expression {
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
} else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
Expand All @@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
children.foldLeft[Any](null)((r, c) => {
val evalc = c.eval(input)
Expand Down Expand Up @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression {
> SELECT _FUNC_(10, 9, 2, 4, 3);
10
""")
case class Greatest(children: Seq[Expression]) extends Expression {
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
} else if (!TypeCoercion.haveSameType(inputTypesForMerging)) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
Expand All @@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
children.foldLeft[Any](null)((r, c) => {
val evalc = c.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
[[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
""", since = "2.4.0")
case class MapConcat(children: Seq[Expression]) extends Expression {
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def checkInputDataTypes(): TypeCheckResult = {
var funcName = s"function $prettyName"
Expand All @@ -521,14 +521,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
}

override def dataType: MapType = {
val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
.getOrElse(MapType(StringType, StringType))
val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
.exists(_.valueContainsNull)
if (dt.valueContainsNull != valueContainsNull) {
dt.copy(valueContainsNull = valueContainsNull)
if (children.isEmpty) {
MapType(StringType, StringType)
} else {
dt
super.dataType.asInstanceOf[MapType]
}
}

Expand Down Expand Up @@ -2211,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
| [1,2,3,4,5,6]
""")
case class Concat(children: Seq[Expression]) extends Expression {
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

Expand All @@ -2232,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
override def dataType: DataType = {
if (children.isEmpty) {
StringType
} else {
super.dataType
}
}

lazy val javaType: String = CodeGenerator.javaType(dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

override def dataType: ArrayType = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why CreateArray doesn't extend ComplexTypeMergingExpression?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the data type of the CreateArray itself is not the merged type but ArrayType.

ArrayType(
children.headOption.map(_.dataType).getOrElse(StringType),
TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType))
.getOrElse(StringType),
containsNull = children.exists(_.nullable))
}

Expand Down Expand Up @@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(
s"$prettyName expects a positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
} else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) {
TypeCheckResult.TypeCheckFailure(
"The given keys of function map should all be the same type, but they are " +
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
} else if (!TypeCoercion.haveSameType(values.map(_.dataType))) {
TypeCheckResult.TypeCheckFailure(
"The given values of function map should all be the same type, but they are " +
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
Expand All @@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {

override def dataType: DataType = {
MapType(
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
.getOrElse(StringType),
valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType))
.getOrElse(StringType),
valueContainsNull = values.exists(_.nullable))
}

Expand Down
Loading