Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Code review
  • Loading branch information
dilipbiswal committed Mar 14, 2017
commit 19cdbb040ccf2e74e1271ca33e6842607c1e0760
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,10 @@ class Analyzer(
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
// The aggregate expressions are treated in a special way by getOuterReferences. If the
// aggregate expression contains only outer reference attributes then the entire aggregate
// expression is isolated as an OuterReference.
// i.e min(OuterReference(b)) => OuterReference(min(b))
outerReferences ++= getOuterReferences(correlated)

// Project cannot host any correlated expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,30 +382,32 @@ object TypeCoercion {
* Analysis Exception will be raised at the type checking phase.
*/
object InConversion extends Rule[LogicalPlan] {
private def flattenExpr(expr: Expression): Seq[Expression] = {
expr match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
// flatten the named struct to get the list of expressions.
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId))) if !i.resolved =>
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = a match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
// flatten the named struct to get the list of expressions.
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
val lhs = flattenExpr(a)

// RHS is the subquery output.
val rhs = sub.output
require(lhs.length == rhs.length)

val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findCommonTypeForBinaryComparison(l.dataType, r.dataType) match {
case d @ Some(_) => d
case _ => findTightestCommonType(l.dataType, r.dataType)
}
findCommonTypeForBinaryComparison(l.dataType, r.dataType)
.orElse(findTightestCommonType(l.dataType, r.dataType))
}

// The number of columns/expressions must match between LHS and RHS of an
Expand All @@ -422,14 +424,11 @@ object TypeCoercion {

// Before constructing the In expression, wrap the multi values in LHS
// in a CreatedNamedStruct.
val newLhs = a match {
case cns: CreateNamedStruct =>
val nameValue = cns.nameExprs.zip(castedLhs).flatMap {
case (name, value) => Seq(name, value)
}
CreateNamedStruct(nameValue)
case _ => castedLhs.head
val newLhs = castedLhs match {
case Seq(lhs) => lhs
case _ => CreateStruct(castedLhs)
}

In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
} else {
i
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,24 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
val isTypeMismatched = valExprs.zip(sub.output).exists {
case (l, r) => l.dataType != r.dataType

val mismatchedColumns = valExprs.zip(sub.output).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
}
if (isTypeMismatched) {

if (mismatchedColumns.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the LHS of an IN subquery
|[${valExprs.map(_.dataType).mkString(", ")}]
|is not compatible with the data type of the output of the subquery
|[${sub.output.map(_.dataType).mkString(", ")}].
|The data type of one or more elements in the left hand side of an IN subquery
|is not compatible with the data type of the output of the subquery
|Mismatched columns:
|[${mismatchedColumns.mkString(", ")}]
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeCheckResult.TypeCheckSuccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ object SubExprUtils extends PredicateHelper {
val outerExpressions = ArrayBuffer.empty[Expression]
conditions foreach { expr =>
expr transformDown {
case a: AggregateExpression if containsOuter(a) =>
case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
val newExpr = stripOuterReference(a)
outerExpressions += newExpr
newExpr
Expand All @@ -210,19 +210,14 @@ object SubExprUtils extends PredicateHelper {
* is removed before returning the predicate to the caller.
*/
def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = {
val correlatedPredicates = ArrayBuffer.empty[Seq[Expression]]
val conditions = plan.collect { case Filter(cond, _) => cond }

// Collect all the expressions that have outer references.
conditions foreach { e =>
conditions.flatMap { e =>
val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter)
stripOuterReferences(correlated) match {
case Nil => // no-op
case xs =>
correlatedPredicates += xs
case Nil => None
case xs => xs
}
}
correlatedPredicates.flatten
}
}

Expand Down