diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e38114ab3cf2..dabcaff91075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -114,7 +114,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val StringNaN = Literal("NaN", StringType) def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -122,20 +122,20 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => - b.makeCopy(Array(b.right, Literal(Double.NaN))) - case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b @ BinaryExpression(StringNaN, DoubleType(r)) => + b.makeCopy(Array(r, Literal(Double.NaN))) + case b @ BinaryExpression(DoubleType(l), StringNaN) => + b.makeCopy(Array(Literal(Double.NaN), l)) /* Float Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => - b.makeCopy(Array(b.right, Literal(Float.NaN))) - case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b @ BinaryExpression(StringNaN, FloatType(r)) => + b.makeCopy(Array(r, Literal(Float.NaN))) + case b @ BinaryExpression(FloatType(l), StringNaN) => + b.makeCopy(Array(Literal(Float.NaN), l)) + + /* Use float NaN by default to avoid unnecessary type widening */ + case b @ BinaryExpression(l @ StringNaN, StringNaN) => + b.makeCopy(Array(Literal(Float.NaN), l)) } } } @@ -168,9 +168,9 @@ trait HiveTypeCoercion { case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. - case (l, r) if l.dataType == StringType && r.dataType != StringType => + case (StringType(l), r) if r.dataType != StringType => (l, Alias(Cast(r, StringType), r.name)()) - case (l, r) if l.dataType != StringType && r.dataType == StringType => + case (l, StringType(r)) if l.dataType != StringType => (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => @@ -211,12 +211,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + case b @ BinaryExpression(l, r) if l.dataType != r.dataType => + findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = - if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) + if (l.dataType == widestType) l else Cast(l, widestType) val newRight = - if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + if (r.dataType == widestType) r else Cast(r, widestType) b.makeCopy(Array(newLeft, newRight)) }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. } @@ -231,51 +231,50 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a: BinaryArithmetic if a.left.dataType == StringType => - a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) - case a: BinaryArithmetic if a.right.dataType == StringType => - a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + case a @ BinaryArithmetic(StringType(l), r) => + a.makeCopy(Array(Cast(l, DoubleType), r)) + case a @ BinaryArithmetic(l, StringType(r)) => + a.makeCopy(Array(l, Cast(r, DoubleType))) // we should cast all timestamp/date/string compare into string compare - case p: BinaryPredicate if p.left.dataType == StringType - && p.right.dataType == DateType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == DateType - && p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryPredicate if p.left.dataType == StringType - && p.right.dataType == TimestampType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == TimestampType - && p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryPredicate if p.left.dataType == TimestampType - && p.right.dataType == DateType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryPredicate if p.left.dataType == DateType - && p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - - case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => - p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) - case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - - case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) => + case p @ BinaryPredicate(StringType(l), DateType(r)) => + p.makeCopy(Array(l, Cast(r, StringType))) + case p @ BinaryPredicate(DateType(l), StringType(r)) => + p.makeCopy(Array(Cast(l, StringType), r)) + case p @ BinaryPredicate(TimestampType(l), DateType(r)) => + p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType))) + case p @ BinaryPredicate(DateType(l), TimestampType(r)) => + p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType))) + case p @ BinaryPredicate(StringType(l), TimestampType(r)) => + p.makeCopy(Array(Cast(l, TimestampType), r)) + case p @ BinaryPredicate(TimestampType(l), StringType(r)) => + p.makeCopy(Array(l, Cast(r, TimestampType))) + + case p @ BinaryPredicate(StringType(l), r) if r.dataType != StringType => + p.makeCopy(Array(Cast(l, DoubleType), r)) + case p @ BinaryPredicate(l, StringType(r)) if l.dataType != StringType => + p.makeCopy(Array(l, Cast(r, DoubleType))) + + case i @ In(DateType(a), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + case i @ In(TimestampType(a), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => + case i @ In(DateType(a), b) if b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => + case i @ In(TimestampType(a), b) if b.forall(_.dataType == DateType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + case i @ In(DateType(a), b) if b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(TimestampType(a), b) if b.forall(_.dataType == StringType) => + i.makeCopy(Array(a, b.map(Cast(_,TimestampType)))) + case i @ In(DateType(a), b) if b.forall(_.dataType == TimestampType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + case i @ In(TimestampType(a), b) if b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e) if e.dataType == StringType => - Sum(Cast(e, DoubleType)) - case Average(e) if e.dataType == StringType => - Average(Cast(e, DoubleType)) - case Sqrt(e) if e.dataType == StringType => - Sqrt(Cast(e, DoubleType)) + case Sum(StringType(e)) => Sum(Cast(e, DoubleType)) + case Average(StringType(e)) => Average(Cast(e, DoubleType)) + case Sqrt(StringType(e)) => Sqrt(Cast(e, DoubleType)) } } @@ -395,19 +394,18 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) + case EqualTo(BooleanType(l), r) if trueValues.contains(r) => l + case EqualTo(l, BooleanType(r)) if trueValues.contains(l) => r + case EqualTo(BooleanType(l), r) if falseValues.contains(r) => Not(l) + case EqualTo(l, BooleanType(r)) if falseValues.contains(l) => Not(r) // No need to change other EqualTo operators as that actually makes sense for boolean types. case e: EqualTo => e // No need to change the EqualNullSafe operators, too case e: EqualNullSafe => e // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison - if p.left.dataType == BooleanType && p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + case p @ BinaryComparison(BooleanType(l), BooleanType(r)) => + p.makeCopy(Array(Cast(l, ByteType), Cast(r, ByteType))) } } @@ -421,18 +419,18 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Skip if the type is boolean type already. Note that this extra cast should be removed // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e + case Cast(BooleanType(e), BooleanType) => e // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) + case Cast(DateType(e), BooleanType) => Cast(e, BooleanType) // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) // Stringify boolean if casting to StringType. // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => + case Cast(BooleanType(e), StringType) => If(e, Literal("true"), Literal("false")) // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => + case Cast(BooleanType(e), dataType) => Cast(If(e, Literal(1), Literal(0)), dataType) } } @@ -447,7 +445,7 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case Cast(e @ StringType(), t: IntegralType) => + case Cast(StringType(e), t: IntegralType) => Cast(Cast(e, DecimalType.Unlimited), t) } } @@ -468,20 +466,20 @@ trait HiveTypeCoercion { children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. - case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. - case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) - case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) + case s @ Sum(DecimalType(e)) => s // Decimal is already the biggest. + case Sum(IntegralType(e)) if e.dataType != LongType => Sum(Cast(e, LongType)) + case Sum(FractionalType(e)) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => + case s @ SumDistinct(DecimalType(e)) => s // Decimal is already the biggest. + case SumDistinct(IntegralType(e)) if e.dataType != LongType => SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => + case SumDistinct(FractionalType(e)) if e.dataType != DoubleType => SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. - case Average(e @ IntegralType()) if e.dataType != LongType => + case s @ Average(DecimalType(e)) => s // Decimal is already the biggest. + case Average(IntegralType(e)) if e.dataType != LongType => Average(Cast(e, LongType)) - case Average(e @ FractionalType()) if e.dataType != DoubleType => + case Average(FractionalType(e)) if e.dataType != DoubleType => Average(Cast(e, DoubleType)) // Hive lets you do aggregation of timestamps... for some reason @@ -503,10 +501,8 @@ trait HiveTypeCoercion { case d: Divide if d.resolved && d.dataType == DoubleType => d case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case Divide(DecimalType(l), r) => Divide(l, Cast(r, DecimalType.Unlimited)) + case Divide(l, DecimalType(r)) => Divide(Cast(l, DecimalType.Unlimited), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } @@ -519,7 +515,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => + case cw @ CaseWhen(branches) if !cw.resolved && branches.forall(_.resolved) => val valueTypes = branches.sliding(2, 2).map { case Seq(_, value) => value.dataType case Seq(elseVal) => elseVal.dataType @@ -547,5 +543,4 @@ trait HiveTypeCoercion { } } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 39b120e8de48..66dd400fd00f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -227,6 +227,10 @@ abstract class Expression extends TreeNode[Expression] { } } +object BinaryExpression { + def unapply(a: BinaryExpression): Option[(Expression, Expression)] = Some((a.left, a.right)) +} + abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => @@ -243,6 +247,4 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8574cabc4352..2c2ddac4579e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types._ -import scala.math.pow case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any @@ -43,10 +42,14 @@ case class Sqrt(child: Expression) extends UnaryExpression { override def toString = s"SQRT($child)" override def eval(input: Row): Any = { - n1(child, input, ((na,a) => math.sqrt(na.toDouble(a)))) + n1(child, input, (na, a) => math.sqrt(na.toDouble(a))) } } +object BinaryArithmetic { + def unapply(a: BinaryArithmetic): Option[(Expression, Expression)] = Some((a.left, a.right)) +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 67f8d411b6bb..dbf5d4822a22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -235,7 +235,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case Cast(e @ BinaryType(), StringType) => + case Cast(BinaryType(e), StringType) => val eval = expressionEvaluator(e) eval.code ++ q""" @@ -247,16 +247,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children - case Cast(child @ NumericType(), IntegerType) => + case Cast(NumericType(child), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) - case Cast(child @ NumericType(), LongType) => + case Cast(NumericType(child), LongType) => child.castOrNull(c => q"$c.toLong", LongType) - case Cast(child @ NumericType(), DoubleType) => + case Cast(NumericType(child), DoubleType) => child.castOrNull(c => q"$c.toDouble", DoubleType) - case Cast(child @ NumericType(), FloatType) => + case Cast(NumericType(child), FloatType) => child.castOrNull(c => q"$c.toFloat", IntegerType) // Special handling required for timestamps in hive test cases since the toString function @@ -301,13 +301,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children */ - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + case GreaterThan(NumericType(e1), NumericType(e2)) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + case GreaterThanOrEqual(NumericType(e1), NumericType(e2)) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + case LessThan(NumericType(e1), NumericType(e2)) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + case LessThanOrEqual(NumericType(e1), NumericType(e2)) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } case And(e1, e2) => @@ -546,7 +546,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { - case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case NativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } } @@ -557,7 +557,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { - case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case NativeType(dt) => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1e22b2d03c67..994661942341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -62,6 +62,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } +object BinaryPredicate { + def unapply(a: BinaryPredicate): Option[(Expression, Expression)] = Some((a.left, a.right)) +} + abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => def nullable = left.nullable || right.nullable @@ -99,7 +103,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) +case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) extends Predicate { def children = child @@ -156,6 +160,10 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate { } } +object BinaryComparison { + def unapply(a: BinaryComparison): Option[(Expression, Expression)] = Some((a.left, a.right)) +} + abstract class BinaryComparison extends BinaryPredicate { self: Product => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 8dda0b182805..c4bfe44dee69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -167,10 +167,7 @@ object DataType { abstract class DataType { /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType == this => true - case _ => false - } + def unapply[T <: Expression](a: T): Option[T] = if (a.dataType == this) Some(a) else None def isPrimitive: Boolean = false @@ -189,7 +186,7 @@ object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - def unapply(dt: DataType): Boolean = all.contains(dt) + def unapply[T <: DataType](dt: T): Option[T] = if (all.contains(dt)) Some(dt) else None val defaultSizeOf: Map[NativeType, Int] = Map( IntegerType -> 4, @@ -288,15 +285,14 @@ abstract class NumericType extends NativeType with PrimitiveType { } object NumericType { - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] + def unapply[T <: Expression](e: T): Option[T] = + if (e.dataType.isInstanceOf[NumericType]) Some(e) else None } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ object IntegralType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[IntegralType] => true - case _ => false - } + def unapply[T <: Expression](a: T): Option[T] = + if (a.dataType.isInstanceOf[IntegralType]) Some(a) else None } abstract class IntegralType extends NumericType { @@ -337,10 +333,8 @@ case object ByteType extends IntegralType { /** Matcher for any expressions that evaluate to [[FractionalType]]s */ object FractionalType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[FractionalType] => true - case _ => false - } + def unapply(a: Expression): Option[Expression] = + if (a.dataType.isInstanceOf[FractionalType]) Some(a) else None } abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType]