Skip to content
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ object TypeCoercion {
}
}

def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = {
if (types.isEmpty) {
None
} else {
types.tail.foldLeft(Option(types.head)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Some(types.head) to avoid an extra null check?

case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2)
case _ => None
}
}
}

/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
Expand Down Expand Up @@ -259,8 +270,22 @@ object TypeCoercion {
}
}

private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1
private def haveSameType(exprs: Seq[Expression]): Boolean = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicated with ComplexTypeMergingExpression.areInputTypesForMergingEqual, can we unify them? We can

  1. remove haveSameType. Any expression that needs to do this check should extend ComplexTypeMergingExpression and the TypeCoercion rule should call areInputTypesForMergingEqual.
  2. remove areInputTypesForMergingEqual. ComplexTypeMergingExpression should only define the list of data types that need to be merged, and TypeCoercion rule should call haveSameType(e.inputTypesForMerging)

Copy link
Member Author

@ueshin ueshin Jul 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have CreateArray and CreateMap, we can't make all such expressions ComplexTypeMergingExpression.
I'd go with approach 2).

if (exprs.size <= 1) {
true
} else {
val head = exprs.head.dataType
exprs.tail.forall(_.dataType.sameType(head))
}
}

/**
 * Returns `expr` unchanged when `expr.dataType.sameType(dt)` holds; otherwise wraps it in a
 * `Cast` to `dt`.
 *
 * NOTE(review): presumably `sameType` ignores nullability flags of nested types, so this avoids
 * adding casts that only change nullability -- confirm against `DataType.sameType`.
 */
private def castIfNotSameType(expr: Expression, dt: DataType): Expression = {
  val alreadySameType = expr.dataType.sameType(dt)
  if (alreadySameType) expr else Cast(expr, dt)
}

/**
* Widens numeric types and converts strings to numbers when appropriate.
Expand Down Expand Up @@ -525,6 +550,7 @@ object TypeCoercion {
* This ensures that the types for various functions are as expected.
*/
object FunctionArgumentConversion extends TypeCoercionRule {

override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet.
Expand All @@ -533,15 +559,15 @@ object TypeCoercion {
case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it needed? I think optimizer can remove unnecessary cast.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the optimizer doesn't remove the cast when the difference of the finalDataType is only the nullabilities of nested types.

case None => a
}

case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
!c.areInputTypesForMergingEqual =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
case None => c
}

Expand All @@ -561,10 +587,10 @@ object TypeCoercion {
}

case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
!m.areInputTypesForMergingEqual =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
case None => m
}

Expand All @@ -575,7 +601,7 @@ object TypeCoercion {
} else {
val types = m.keys.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
case None => m.keys
}
}
Expand All @@ -585,7 +611,7 @@ object TypeCoercion {
} else {
val types = m.values.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
case None => m.values
}
}
Expand All @@ -610,27 +636,27 @@ object TypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(es) =>
case c @ Coalesce(es) if !c.areInputTypesForMergingEqual =>
val types = es.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType)))
case None => c
}

// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
// we need to truncate, but we should not promote one side to string if the other side is
// string.
case g @ Greatest(children) if !haveSameType(children) =>
case g @ Greatest(children) if !g.areInputTypesForMergingEqual =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are not just concat, can you update the PR title?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated PR title and description.

val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
case None => g
}

case l @ Least(children) if !haveSameType(children) =>
case l @ Least(children) if !l.areInputTypesForMergingEqual =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
case None => l
}

Expand Down Expand Up @@ -707,8 +733,8 @@ object TypeCoercion {
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
val newLeft = castIfNotSameType(left, widestType)
val newRight = castIfNotSameType(right, widestType)
If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
case If(Literal(null, NullType), left, right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
Expand Down Expand Up @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
> SELECT _FUNC_(10, 9, 2, 4, 3);
2
""")
case class Least(children: Seq[Expression]) extends Expression {
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
} else if (!areInputTypesForMergingEqual) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
Expand All @@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
children.foldLeft[Any](null)((r, c) => {
val evalc = c.eval(input)
Expand Down Expand Up @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression {
> SELECT _FUNC_(10, 9, 2, 4, 3);
10
""")
case class Greatest(children: Seq[Expression]) extends Expression {
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
} else if (!areInputTypesForMergingEqual) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
Expand All @@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
children.foldLeft[Any](null)((r, c) => {
val evalc = c.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
[[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
""", since = "2.4.0")
case class MapConcat(children: Seq[Expression]) extends Expression {
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

override def checkInputDataTypes(): TypeCheckResult = {
var funcName = s"function $prettyName"
Expand All @@ -527,14 +527,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
}

override def dataType: MapType = {
val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
.getOrElse(MapType(StringType, StringType))
val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
.exists(_.valueContainsNull)
if (dt.valueContainsNull != valueContainsNull) {
dt.copy(valueContainsNull = valueContainsNull)
if (children.isEmpty) {
MapType(StringType, StringType)
} else {
dt
super.dataType.asInstanceOf[MapType]
}
}

Expand Down Expand Up @@ -2217,7 +2213,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
| [1,2,3,4,5,6]
""")
case class Concat(children: Seq[Expression]) extends Expression {
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

Expand All @@ -2238,7 +2234,13 @@ case class Concat(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
override def dataType: DataType = {
if (children.isEmpty) {
StringType
} else {
super.dataType
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we handle this case in type coercion (analysis phase)?

Copy link
Member Author

@ueshin ueshin Jul 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, Concat for array type has the type coercion to add casts to make all children the same type, but we also have the optimization SimplifyCasts to remove unnecessary casts which might remove casts from arrays not containing null to arrays containing null (optimizer/expressions.scala#L611).

E.g., concat(array(1,2,3), array(4,null,6)) might generate a wrong data type during the execution.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test to show the wrong nullability.

Copy link
Member

@maropu maropu Jul 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I see. But I just have a hunch that SimplifyCasts cannot simplify array casts in some cases, e.g., this concat case. Since we basically cannot change plan semantics in the optimization phase, I feel a little weird about this simplification. Anyway, I'm ok with your approach because I can't find a better & simpler way to solve this in the analysis phase... Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that also makes sense. I'm not sure we can remove the simplification, though. cc @gatorsmile @cloud-fan

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, Coalesce should be fixed, and Least and Greatest are also suspicious.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about changing SimplifyCasts rule to start replacing Cast with a new dummy cast expression that will hold only a target data type and won't perform any casting?

This can work, but my point is we should not add the cast to change containsNull, as it may not match the underlying child expression and generates unnecessary null check code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. In that case, it would be nice to introduce a method that resolves the output DataType and merges the nullable/containsNull flags of non-primitive types recursively for such expressions. For most cases we could encapsulate the function into a new trait. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for the @mn-mikke idea


lazy val javaType: String = CodeGenerator.javaType(dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

override def dataType: ArrayType = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why CreateArray doesn't extend ComplexTypeMergingExpression?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the data type of the CreateArray itself is not the merged type but ArrayType.

ArrayType(
children.headOption.map(_.dataType).getOrElse(StringType),
TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType))
.getOrElse(StringType),
containsNull = children.exists(_.nullable))
}

Expand Down Expand Up @@ -179,11 +180,13 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(
s"$prettyName expects a positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
} else if (keys.length > 1 &&
keys.map(_.dataType).sliding(2, 1).exists { case Seq(t1, t2) => !t1.sameType(t2) }) {
TypeCheckResult.TypeCheckFailure(
"The given keys of function map should all be the same type, but they are " +
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
} else if (values.length > 1 &&
values.map(_.dataType).sliding(2, 1).exists { case Seq(t1, t2) => !t1.sameType(t2) }) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checks for keys and values are very similar. Would it be possible to separate the common logic into a private method?

TypeCheckResult.TypeCheckFailure(
"The given values of function map should all be the same type, but they are " +
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
Expand All @@ -194,8 +197,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {

override def dataType: DataType = {
MapType(
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
.getOrElse(StringType),
valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType))
.getOrElse(StringType),
valueContainsNull = values.exists(_.nullable))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ object Literal {
case map: MapType => create(Map(), map)
case struct: StructType =>
create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
case udt: UserDefinedType[_] => default(udt.sqlType)
case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt)
case other =>
throw new RuntimeException(s"no default for type $dataType")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
Expand All @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
1
""")
// scalastyle:on line.size.limit
case class Coalesce(children: Seq[Expression]) extends Expression {
case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression {

/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
override def nullable: Boolean = children.forall(_.nullable)
Expand All @@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
var result: Any = null
val childIterator = children.iterator
Expand Down
Loading