-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions. #21704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions. #21704
Changes from 17 commits
d87a8c6
30d5aed
2e624df
fa73b32
da0702b
b31e401
7a838b0
01c9ff3
b2ca587
3d8891e
1fa692a
444383d
3e1f7e4
2c54e38
5f1b865
2ab025f
db254e5
b412f7b
f701242
5115961
e489e8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,6 +184,17 @@ object TypeCoercion { | |
| } | ||
| } | ||
|
|
||
| def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { | ||
| if (types.isEmpty) { | ||
| None | ||
| } else { | ||
| types.tail.foldLeft(Option(types.head)) { | ||
| case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) | ||
| case _ => None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Case 2 type widening (see the classdoc comment above for TypeCoercion). | ||
| * | ||
|
|
@@ -259,8 +270,22 @@ object TypeCoercion { | |
| } | ||
| } | ||
|
|
||
| private def haveSameType(exprs: Seq[Expression]): Boolean = | ||
| exprs.map(_.dataType).distinct.length == 1 | ||
| private def haveSameType(exprs: Seq[Expression]): Boolean = { | ||
|
||
| if (exprs.size <= 1) { | ||
| true | ||
| } else { | ||
| val head = exprs.head.dataType | ||
| exprs.tail.forall(_.dataType.sameType(head)) | ||
| } | ||
| } | ||
|
|
||
| private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { | ||
| if (!expr.dataType.sameType(dt)) { | ||
| Cast(expr, dt) | ||
| } else { | ||
| expr | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Widens numeric types and converts strings to numbers when appropriate. | ||
|
|
@@ -525,6 +550,7 @@ object TypeCoercion { | |
| * This ensure that the types for various functions are as expected. | ||
| */ | ||
| object FunctionArgumentConversion extends TypeCoercionRule { | ||
|
|
||
| override protected def coerceTypes( | ||
| plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
| // Skip nodes who's children have not been resolved yet. | ||
|
|
@@ -533,15 +559,15 @@ object TypeCoercion { | |
| case a @ CreateArray(children) if !haveSameType(children) => | ||
| val types = children.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it needed? I think optimizer can remove unnecessary cast.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently the optimizer doesn't remove the cast when the difference of the |
||
| case None => a | ||
| } | ||
|
|
||
| case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && | ||
| !haveSameType(children) => | ||
| !c.areInputTypesForMergingEqual => | ||
| val types = children.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) | ||
| case None => c | ||
| } | ||
|
|
||
|
|
@@ -561,10 +587,10 @@ object TypeCoercion { | |
| } | ||
|
|
||
| case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && | ||
| !haveSameType(children) => | ||
| !m.areInputTypesForMergingEqual => | ||
| val types = children.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) | ||
| case None => m | ||
| } | ||
|
|
||
|
|
@@ -575,7 +601,7 @@ object TypeCoercion { | |
| } else { | ||
| val types = m.keys.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) | ||
| case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) | ||
| case None => m.keys | ||
| } | ||
| } | ||
|
|
@@ -585,7 +611,7 @@ object TypeCoercion { | |
| } else { | ||
| val types = m.values.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) | ||
| case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) | ||
| case None => m.values | ||
| } | ||
| } | ||
|
|
@@ -610,27 +636,27 @@ object TypeCoercion { | |
| // Coalesce should return the first non-null value, which could be any column | ||
| // from the list. So we need to make sure the return type is deterministic and | ||
| // compatible with every child column. | ||
| case c @ Coalesce(es) if !haveSameType(es) => | ||
| case c @ Coalesce(es) if !c.areInputTypesForMergingEqual => | ||
| val types = es.map(_.dataType) | ||
| findWiderCommonType(types) match { | ||
| case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) | ||
| case None => c | ||
| } | ||
|
|
||
| // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if | ||
| // we need to truncate, but we should not promote one side to string if the other side is | ||
| // string.g | ||
| case g @ Greatest(children) if !haveSameType(children) => | ||
| case g @ Greatest(children) if !g.areInputTypesForMergingEqual => | ||
|
||
| val types = children.map(_.dataType) | ||
| findWiderTypeWithoutStringPromotion(types) match { | ||
| case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) | ||
| case None => g | ||
| } | ||
|
|
||
| case l @ Least(children) if !haveSameType(children) => | ||
| case l @ Least(children) if !l.areInputTypesForMergingEqual => | ||
| val types = children.map(_.dataType) | ||
| findWiderTypeWithoutStringPromotion(types) match { | ||
| case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) | ||
| case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) | ||
| case None => l | ||
| } | ||
|
|
||
|
|
@@ -707,8 +733,8 @@ object TypeCoercion { | |
| // Find tightest common type for If, if the true value and false value have different types. | ||
| case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => | ||
| findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => | ||
| val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) | ||
| val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) | ||
| val newLeft = castIfNotSameType(left, widestType) | ||
| val newRight = castIfNotSameType(right, widestType) | ||
| If(pred, newLeft, newRight) | ||
| }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. | ||
| case If(Literal(null, NullType), left, right) => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -513,7 +513,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp | |
| > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); | ||
| [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] | ||
| """, since = "2.4.0") | ||
| case class MapConcat(children: Seq[Expression]) extends Expression { | ||
| case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| var funcName = s"function $prettyName" | ||
|
|
@@ -527,14 +527,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { | |
| } | ||
|
|
||
| override def dataType: MapType = { | ||
| val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption | ||
| .getOrElse(MapType(StringType, StringType)) | ||
| val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) | ||
| .exists(_.valueContainsNull) | ||
| if (dt.valueContainsNull != valueContainsNull) { | ||
| dt.copy(valueContainsNull = valueContainsNull) | ||
| if (children.isEmpty) { | ||
| MapType(StringType, StringType) | ||
| } else { | ||
| dt | ||
| super.dataType.asInstanceOf[MapType] | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -2217,7 +2213,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); | ||
| | [1,2,3,4,5,6] | ||
| """) | ||
| case class Concat(children: Seq[Expression]) extends Expression { | ||
| case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { | ||
|
|
||
| private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH | ||
|
|
||
|
|
@@ -2238,7 +2234,13 @@ case class Concat(children: Seq[Expression]) extends Expression { | |
| } | ||
| } | ||
|
|
||
| override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) | ||
| override def dataType: DataType = { | ||
| if (children.isEmpty) { | ||
| StringType | ||
| } else { | ||
| super.dataType | ||
| } | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we handle this case in type coercion (analysis phase)?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, E.g.,
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a test to show the wrong nullability.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha, I see. But, I just have a hunch that
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, that also makes sense. I'm not sure we can remove the simplification, though. cc @gatorsmile @cloud-fan
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This can work, but my point is we should not add the cast to change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, see. In that case, it would be nice to introduce a method that will resolve the output DataType and merges
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SGTM
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 for the @mn-mikke idea |
||
|
|
||
| lazy val javaType: String = CodeGenerator.javaType(dataType) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,8 +18,8 @@ | |
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} | ||
| import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util._ | ||
|
|
@@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { | |
|
|
||
| override def dataType: ArrayType = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the data type of the |
||
| ArrayType( | ||
| children.headOption.map(_.dataType).getOrElse(StringType), | ||
| TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) | ||
| .getOrElse(StringType), | ||
| containsNull = children.exists(_.nullable)) | ||
| } | ||
|
|
||
|
|
@@ -179,11 +180,13 @@ case class CreateMap(children: Seq[Expression]) extends Expression { | |
| if (children.size % 2 != 0) { | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"$prettyName expects a positive even number of arguments.") | ||
| } else if (keys.map(_.dataType).distinct.length > 1) { | ||
| } else if (keys.length > 1 && | ||
| keys.map(_.dataType).sliding(2, 1).exists { case Seq(t1, t2) => !t1.sameType(t2) }) { | ||
| TypeCheckResult.TypeCheckFailure( | ||
| "The given keys of function map should all be the same type, but they are " + | ||
| keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) | ||
| } else if (values.map(_.dataType).distinct.length > 1) { | ||
| } else if (values.length > 1 && | ||
| values.map(_.dataType).sliding(2, 1).exists { case Seq(t1, t2) => !t1.sameType(t2) }) { | ||
|
||
| TypeCheckResult.TypeCheckFailure( | ||
| "The given values of function map should all be the same type, but they are " + | ||
| values.map(_.dataType.simpleString).mkString("[", ", ", "]")) | ||
|
|
@@ -194,8 +197,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { | |
|
|
||
| override def dataType: DataType = { | ||
| MapType( | ||
| keyType = keys.headOption.map(_.dataType).getOrElse(StringType), | ||
| valueType = values.headOption.map(_.dataType).getOrElse(StringType), | ||
| keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) | ||
| .getOrElse(StringType), | ||
| valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) | ||
| .getOrElse(StringType), | ||
| valueContainsNull = values.exists(_.nullable)) | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
Some(types.head)to avoid an extra null check?