Skip to content
Next Next commit
Supporting subqueries inside where 'in' clause
Conflicts:
	sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
  • Loading branch information
ravipesala committed Feb 27, 2015
commit 0a41e91100c1fbf2a4851c5df7b475f1b1d15dad
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ class SqlParser extends AbstractSparkSQLParser {
| termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
| termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) }
| termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) }
| termExpression ~ (IN ~ "(" ~> start <~ ")") ^^ {
case e1 ~ e2 => SubqueryExpression(e1, e2)
}
| termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
case e1 ~ e2 => In(e1, e2)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Analyzer(catalog: Catalog,
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ImplicitGenerate ::
SubQueryExpressions ::
ResolveFunctions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
Expand Down Expand Up @@ -422,6 +423,64 @@ class Analyzer(catalog: Catalog,
Generate(g, join = false, outer = false, None, child)
}
}

/**
* Rewrites a query whose WHERE clause contains an IN-subquery expression into a left semi join.
* For example, `select T1.x from T1 where T1.x in (select T2.y from T2)` is transformed to
* `select T1.x from T1 left semi join T2 on T1.x = T2.y`.
*/
object SubQueryExpressions extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(conditions, child) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not going to handle the non Subquery case here right? how about

case filter @ Filter(In(expr, SubqueryExpression(subquery)), child) =>

val subqueryExprs = new scala.collection.mutable.ArrayBuffer[SubqueryExpression]()
val nonSubQueryConds = new scala.collection.mutable.ArrayBuffer[Expression]()
val transformedConds = conditions.transform{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space before {.

Also I would consider doing this in two steps to avoid depending on transform for side effects: a collect to get the list and then a transform to replace with true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Done in two steps.

// Replace with dummy
case s @ SubqueryExpression(exp,subquery) =>
subqueryExprs += s
Literal(true)
}
if(subqueryExprs.size > 0) {
val subqueryExpr = subqueryExprs.remove(0)
val firstJoin = createLeftSemiJoin(
child, subqueryExpr.exp, subqueryExpr.child, transformedConds)
subqueryExprs.foldLeft(firstJoin){case(fj, sq) =>
createLeftSemiJoin(fj, sq.exp, sq.child)}
} else {
filter
}
}

/**
 * Builds a left semi join between `left` and the rewritten `subquery`.
 *
 * The subquery is rewritten by `transformAndGetConditions`, which also returns the
 * conditions relating `expression` to the subquery. Those conditions become the join
 * condition, AND-ed after `parentConds` when the caller supplies one.
 *
 * @param left        plan used as the left side of the join
 * @param expression  expression that is compared against the subquery's output
 * @param subquery    plan of the subquery
 * @param parentConds conditions of the enclosing filter to fold into the join condition;
 *                    `null` (the default) means there are none
 * @return a `LeftSemi` `Join` carrying the combined condition
 */
def createLeftSemiJoin(left: LogicalPlan,
    expression: Expression, subquery: LogicalPlan,
    parentConds: Expression = null) : LogicalPlan = {
  val (transformedPlan, subqueryConds) = transformAndGetConditions(
    expression, subquery)
  // Unify the parent query conditions and subquery conditions and add these as join conditions.
  // Option(...) turns the `null` default into None, so no explicit null check is needed.
  val unifyConds: Expression =
    Option(parentConds).map(And(_, subqueryConds)).getOrElse(subqueryConds)
  Join(left, transformedPlan, LeftSemi, Some(unifyConds))
}

/**
 * Rewrites the subquery `plan` and gathers the conditions that relate it to `expression`.
 *
 * For every `Project` matched during the traversal, an equality between `expression` and
 * the first projected column is recorded. When the `Project` sits directly on a `Filter`,
 * the filter condition is recorded as well, the `Filter` node is dropped from the plan, and
 * the attributes it references (those resolvable against the filter's child and not already
 * projected) are appended to the projection.
 *
 * @param expression expression to equate with the first projected column of the subquery
 * @param plan       plan of the subquery
 * @return the rewritten plan together with all recorded conditions combined with `And`
 */
def transformAndGetConditions(expression: Expression,
    plan: LogicalPlan): (LogicalPlan, Expression) = {
  // Filled in as a side effect while `transform` walks the plan.
  val joinConds = new scala.collection.mutable.ArrayBuffer[Expression]()
  val rewrittenPlan = plan transform {
    case Project(projectList, subFilter @ Filter(condition, child)) =>
      joinConds += EqualTo(expression, projectList(0))
      joinConds += condition
      val resolvedChild = ResolveRelations(child)
      // Add the expressions to the projections which are used as filters in subquery
      val extraAttrs = subFilter.references.filter { attr =>
        resolvedChild.resolve(attr.name, resolver).isDefined && !projectList.contains(attr)
      }
      Project(projectList ++ extraAttrs, child)
    case project @ Project(projectList, _) =>
      joinConds += EqualTo(expression, projectList(0))
      project
  }
  // NOTE(review): `reduce` throws if no Project was matched above — confirm callers
  // always pass a plan containing a Project.
  (rewrittenPlan, joinConds.reduce((lhs, rhs) => And(lhs, rhs)))
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1049,4 +1049,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
rdd.toDF().registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}

test("SPARK-4226 Add support for subqueries in predicates") {
  // IN predicate whose right-hand side is a subquery (which itself uses a literal IN list).
  // The expected answer is wrapped in a Row, like the other checkAnswer calls in this suite.
  checkAnswer(
    sql(
      """SELECT a.key FROM testData a
        |WHERE a.key in
        |(SELECT b.key FROM testData b WHERE b.key in (1))""".stripMargin),
    Row(1))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString)

/* Subquery expressions in where condition */
case Token("TOK_SUBQUERY_EXPR",
Token("TOK_SUBQUERY_OP",
Token("in", Nil) :: Nil) ::
query :: exprsn :: Nil) =>
SubqueryExpression(nodeToExpr(exprsn),nodeToPlan(query))
// This code is adapted from
// /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
case ast: ASTNode if numericAstTypes contains ast.getType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,4 +416,23 @@ class SQLQuerySuite extends QueryTest {
dropTempTable("data")
setConf("spark.sql.hive.convertCTAS", originalConf)
}

test("SPARK-4226 Add support for subqueries in predicates- Uncorelated queries") {
  // Subquery that does not reference the outer query's columns.
  val inSubqueryQuery =
    """SELECT a.key FROM src a
      |WHERE a.key in
      |(SELECT b.key FROM src b WHERE b.key in (230))""".stripMargin
  // Plain IN-list query used to produce the expected rows.
  val baselineQuery = "SELECT key FROM src WHERE key in (230)"
  checkAnswer(sql(inSubqueryQuery), sql(baselineQuery).collect().toSeq)
}

test("SPARK-4226 Add support for subqueries in predicates- corelated queries") {
  // Subquery whose WHERE clause also references a column of the outer query (a.value).
  val inSubqueryQuery =
    """SELECT a.key FROM src a
      |WHERE a.key in
      |(SELECT b.key FROM src b WHERE b.key in (230)and a.value=b.value)""".stripMargin
  // Plain IN-list query used to produce the expected rows.
  val baselineQuery = "SELECT key FROM src WHERE key in (230)"
  checkAnswer(sql(inSubqueryQuery), sql(baselineQuery).collect().toSeq)
}

}