init import.
viirya committed Mar 17, 2016
commit c2f9b058f6ad243e35518bc1758fe7ef7ac2c25f
@@ -202,3 +202,25 @@ object Unions {
    }
  }
}

/**
 * A pattern that finds the original expression from a sequence of casts.
 */
object Casts {
  def unapply(expr: Expression): Option[Attribute] = expr match {
    case c: Cast => collectCasts(expr)
    case _ => None
  }

  private def collectCasts(e: Expression): Option[Attribute] = {
Member:

Can use pattern matching to simplify this block. Perhaps something along the lines of:

  private def collectCasts(expr: Expression): Option[Attribute] = {
    expr match {
      case e: Cast => collectCasts(e.child)
      case e: Attribute => Some(e)
      case _ => None
    }
  }

Contributor:

:)

Member:

Great! Also add the @tailrec annotation?

    if (e.isInstanceOf[Cast]) {
      collectCasts(e.children(0))
Member:

e.child for better readability?

    } else {
      if (e.isInstanceOf[Attribute]) {
        Some(e.asInstanceOf[Attribute])
Contributor:

Any time you see isInstanceOf / asInstanceOf you might want to be using matching instead...

e match {
  case a: Attribute => Some(a)
  case Cast(c, _) => collectCasts(c)
  case _ => None
}

      } else {
        None
      }
    }
  }
}
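
Putting the review suggestions above together (pattern matching instead of isInstanceOf/asInstanceOf, e.child instead of e.children(0), and the @tailrec annotation), collectCasts reduces to roughly the sketch below. The types are the real Catalyst classes this diff touches, but the shape is a suggestion, not the committed code:

  import scala.annotation.tailrec

  import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}

  object Casts {
    def unapply(expr: Expression): Option[Attribute] = expr match {
      case c: Cast => collectCasts(c)
      case _ => None
    }

    // Both recursive calls are in tail position, so @tailrec compiles and a
    // long chain of casts is unwound without growing the stack.
    @tailrec
    private def collectCasts(expr: Expression): Option[Attribute] = expr match {
      case e: Cast => collectCasts(e.child)
      case e: Attribute => Some(e)
      case _ => None
    }
  }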
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.Casts
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}

@@ -36,6 +37,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
      .union(constructIsNotNullConstraints(constraints))
      .filter(constraint =>
        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
      .map(_.transform {
Member:

Instead of adding this here, how about we add something like .map(_.transform { case IsNotNull(Casts(a)) => IsNotNull(a) }) in constructIsNotNullConstraints? Wouldn't that avoid the redundant outputSet check?

        case n @ IsNotNull(c) =>
          c match {
            case Casts(a) if outputSet.contains(a) => IsNotNull(a)
            case _ => n
          }
      })
  }

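For reference, the alternative floated in the comment above, moving the IsNotNull(Casts(a)) => IsNotNull(a) rewrite into constructIsNotNullConstraints so the extra outputSet guard here becomes unnecessary, is illustrated by the self-contained toy below. These case classes are stand-ins for illustration only (Catalyst's Cast also carries a target DataType); this is not Spark's actual implementation:

  // Toy stand-ins, just to show the rewrite's shape.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Cast(child: Expr) extends Expr
  case class IsNotNull(child: Expr) extends Expr

  object Casts {
    // Like the PR's extractor: only matches when at least one cast is present.
    def unapply(e: Expr): Option[Attr] = e match {
      case Cast(c) => strip(c)
      case _ => None
    }

    @scala.annotation.tailrec
    private def strip(e: Expr): Option[Attr] = e match {
      case Cast(c) => strip(c)
      case a: Attr => Some(a)
      case _ => None
    }
  }

  // An IsNotNull over any chain of casts collapses to IsNotNull over the
  // underlying attribute; every other constraint passes through unchanged.
  def stripCasts(constraint: Expr): Expr = constraint match {
    case IsNotNull(Casts(a)) => IsNotNull(a)
    case other => other
  }

  // stripCasts(IsNotNull(Cast(Cast(Attr("x"))))) == IsNotNull(Attr("x"))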