Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add error message and tests
  • Loading branch information
cloud-fan committed Jun 1, 2015
commit 69ca3feec42d54b3845c49d95b9890a0a40c4327
Original file line number Diff line number Diff line change
Expand Up @@ -619,18 +619,13 @@ trait HiveTypeCoercion {
*/
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// Skip nodes who's children have not been resolved yet or input types do not match.
case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed offline, maybe it'd make sense to have resolved also call the data check function.


// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
Divide(Cast(l, DecimalType.Unlimited), r)

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
* todo
* Check the input data types, returns `TypeCheckResult.success` if it's valid,
* or return a `TypeCheckResult` with an error message if invalid.
*/
def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ abstract class UnaryArithmetic extends UnaryExpression {
override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType

override def checkInputDataTypes: TypeCheckResult = {
if (TypeUtils.validForNumericExpr(child.dataType)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
Expand All @@ -52,6 +44,9 @@ abstract class UnaryArithmetic extends UnaryExpression {
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
override def toString: String = s"-$child"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "operator -")

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
Expand All @@ -62,6 +57,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
override def nullable: Boolean = true
override def toString: String = s"SQRT($child)"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")

private lazy val numeric = TypeUtils.getNumeric(child.dataType)

protected override def evalInternal(evalE: Any) = {
Expand All @@ -77,6 +75,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
case class Abs(child: Expression) extends UnaryArithmetic {
override def toString: String = s"Abs($child)"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function abs")

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
Expand All @@ -87,10 +88,10 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def dataType: DataType = left.dataType

override def checkInputDataTypes: TypeCheckResult = {
override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.fail(
s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}")
s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}")
} else {
checkTypesInternal(dataType)
}
Expand Down Expand Up @@ -123,13 +124,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
// for `Add` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -143,13 +139,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
// for `Subtract` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -163,13 +154,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
// for `Multiply` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -184,13 +170,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
// for `Divide` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
Expand Down Expand Up @@ -220,13 +201,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
// for `Remainder` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
Expand Down Expand Up @@ -254,13 +230,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "&"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -282,13 +253,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "|"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -310,13 +276,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -338,13 +299,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
override def toString: String = s"~$child"

override def checkInputDataTypes: TypeCheckResult = {
if (TypeUtils.validForBitwiseExpr(dataType)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")

private lazy val not: (Any) => Any = dataType match {
case ByteType =>
Expand All @@ -363,13 +319,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForOrderingExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function maxOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand All @@ -395,13 +346,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForOrderingExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function minOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"

Expand Down
Loading