diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index c363a5efacde..0732452273a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -247,6 +247,9 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } + | termExpression ~ (IN ~ "(" ~> start <~ ")") ^^ { + case e1 ~ e2 => In(e1, Seq(SubqueryExpression(e2))) + } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e4e542562f22..c1f2aab9ae41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,6 +21,7 @@ import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -59,6 +60,7 @@ class Analyzer(catalog: Catalog, ResolveGroupingAnalytics :: ResolveSortReferences :: ImplicitGenerate :: + SubQueryExpressions :: ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: @@ -422,6 +424,108 @@ class Analyzer(catalog: Catalog, Generate(g, 
join = false, outer = false, None, child)
     }
   }
+
+  /**
+   * Rewrites a filter whose condition contains an `IN (subquery)` predicate into a left semi
+   * join. For example
+   *   select T1.x from T1 where T1.x in (select T2.y from T2)
+   * is transformed to
+   *   select T1.x from T1 left semi join T2 on T1.x = T2.y.
+   *
+   * Only one subquery predicate per filter is supported, and it has to be a top level conjunct
+   * of the filter condition. Any other shape is rejected with a [[TreeNodeException]].
+   */
+  object SubQueryExpressions extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: LogicalPlan if !p.childrenResolved => p
+      case filter @ Filter(conditions, child) =>
+        val subqueryExprs = conditions.collect {
+          case In(exp, Seq(SubqueryExpression(subquery))) => (exp, subquery)
+        }
+        subqueryExprs match {
+          case Seq() => filter // No subqueries.
+          case Seq((exp, subquery)) =>
+            // A semi join only keeps the matching rows, so it is equivalent to the predicate
+            // only when the predicate is AND-ed at the top level. Under NOT or OR the rewrite
+            // below would silently return wrong rows, so reject those shapes explicitly.
+            if (!isTopLevelConjunct(conditions)) {
+              throw new TreeNodeException(
+                filter,
+                "SubQuery expression is supported only as a top level conjunct of the condition.")
+            }
+            // Replace the subquery with a dummy true literal since it is evaluated by the join.
+            val transformedConds = conditions.transform {
+              case In(_, Seq(SubqueryExpression(_))) => Literal(true)
+            }
+            createLeftSemiJoin(child, exp, subquery, transformedConds)
+          case _ =>
+            throw new TreeNodeException(filter, "Only one SubQuery expression is supported.")
+        }
+    }
+
+    /**
+     * Returns true when an `IN (subquery)` predicate is reachable from `condition` by going
+     * through `And` nodes only.
+     */
+    private def isTopLevelConjunct(condition: Expression): Boolean = condition match {
+      case And(left, right) => isTopLevelConjunct(left) || isTopLevelConjunct(right)
+      case In(_, Seq(SubqueryExpression(_))) => true
+      case _ => false
+    }
+
+    /**
+     * Creates a LeftSemi join between the parent query and the subquery of the 'IN' predicate.
+     * The join condition combines the parent query conditions and the subquery conditions.
+     *
+     * @param left plan the parent filter was applied to
+     * @param value left hand side expression of the 'IN' predicate
+     * @param subquery plan of the subquery
+     * @param parentConds parent filter condition with the subquery predicate replaced
+     */
+    def createLeftSemiJoin(
+        left: LogicalPlan,
+        value: Expression,
+        subquery: LogicalPlan,
+        parentConds: Expression): LogicalPlan = {
+      val (transformedPlan, subqueryConds) = transformAndGetConditions(value, subquery)
+      // Add both parent query conditions and subquery conditions as join conditions
+      val allPredicates = And(parentConds, subqueryConds)
+      Join(left, transformedPlan, LeftSemi, Some(allPredicates))
+    }
+
+    /**
+     * Transforms the subquery plan so that every expression used by the subquery filter is
+     * exposed by its projection under a generated `sqc<index>` alias, and returns the
+     * transformed plan together with the conditions to be used in the join.
+     *
+     * @param value left hand side expression of the 'IN' predicate
+     * @param subquery plan of the subquery
+     * @return the transformed subquery plan and the join conditions collected from it
+     * @throws TreeNodeException if a select list does not contain exactly one item, or if the
+     *                           subquery contains no projection to join on
+     */
+    def transformAndGetConditions(
+        value: Expression,
+        subquery: LogicalPlan): (LogicalPlan, Expression) = {
+      // Filled as a side effect while the projections of the subquery are rewritten below.
+      val joinConds = new scala.collection.mutable.ArrayBuffer[Expression]()
+      // TODO : we only decorrelate subqueries in very specific cases like the cases mentioned
+      // above in documentation. The more complex queries like using of subqueries inside
+      // subqueries can be supported in future.
+      val transformedPlan = subquery transform {
+        case project @ Project(projectList, f @ Filter(condition, child)) =>
+          checkSingleSelectItem(project)
+          val resolvedChild = ResolveRelations(child)
+          // Add the expressions to the projections which are used as filters in subquery
+          val toBeAddedExprs = f.references.filter { a =>
+            resolvedChild.resolve(a.name, resolver).isDefined && !project.outputSet.contains(a)
+          }
+          val projected = projectList ++ toBeAddedExprs
+          // Create every alias exactly once, so that the rewritten condition and the new
+          // projection refer to the very same named expression.
+          val aliases = projected.zipWithIndex.map {
+            case (exp, index) => Alias(exp, s"sqc$index")()
+          }
+          val aliasByName = projected.map(_.name).zip(aliases).toMap
+          // Replace the condition column names with alias names. Attributes without an alias
+          // are left untouched instead of failing on a missing map entry.
+          val transformedConds = condition.transform {
+            case a: Attribute if resolvedChild.resolve(a.name, resolver).isDefined =>
+              aliasByName.get(a.name).map(_.toAttribute).getOrElse(a)
+          }
+          // Join the first projection column of subquery to the main query and add as condition
+          // TODO : We can avoid if the parent condition already has this condition.
+          joinConds += EqualTo(value, aliases.head.toAttribute)
+          joinConds += transformedConds
+          Project(aliases, child)
+        case project @ Project(projectList, child) =>
+          checkSingleSelectItem(project)
+          // No filter below the projection (the uncorrelated case): only the select item needs
+          // an alias, and it is used as the join condition.
+          val aliases = projectList.zipWithIndex.map {
+            case (exp, index) => Alias(exp, s"sqc$index")()
+          }
+          joinConds += EqualTo(value, aliases.head.toAttribute)
+          Project(aliases, child)
+      }
+      joinConds.reduceOption(And) match {
+        case Some(conds) => (transformedPlan, conds)
+        case None =>
+          // No projection matched, so there is no column to compare `value` with.
+          throw new TreeNodeException(
+            subquery,
+            "SubQuery should have a Select List to be used in 'IN' predicate")
+      }
+    }
+
+    /** Fails unless the select list of the given subquery projection has exactly one item. */
+    private def checkSingleSelectItem(project: Project): Unit = {
+      if (project.projectList.size != 1) {
+        throw new TreeNodeException(
+          project,
+          "SubQuery can contain only one item in Select List")
+      }
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubqueryExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubqueryExpression.scala
new file mode 100644
index 000000000000..99cbc7cfa4d7
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubqueryExpression.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A subquery used as the right hand side of an `IN` predicate.
+ * For example : 'SELECT * FROM src a WHERE a.key in (SELECT b.key FROM src b)'
+ * The expression always reports itself as unresolved and cannot be evaluated; calling `eval`
+ * fails because the predicate is expected to be converted to a join query.
+ * @param subquery In the above example 'SELECT b.key FROM src b' is 'subquery'
+ */
+case class SubqueryExpression(subquery: LogicalPlan) extends Expression {
+
+  type EvaluatedType = Any
+
+  /** Data type of the first output column of the subquery. */
+  def dataType: DataType = subquery.output.headOption.map(_.dataType).getOrElse(
+    sys.error(s"SubqueryExpression has no output column to take the data type from: $subquery"))
+
+  override def foldable: Boolean = false
+  def nullable: Boolean = true
+  override def toString: String = s"SubqueryExpression(${subquery.output.mkString(",")})"
+  // Never resolved as an expression.
+  override lazy val resolved = false
+  def children: Seq[Expression] = Nil
+
+  override def eval(input: Row): Any =
+    sys.error("SubqueryExpression eval should not be called since it will be converted" +
+      " to join query")
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 097bf0dd23c8..753e6c4f2f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1049,4 +1049,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     rdd.toDF().registerTempTable("distinctData")
     checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
   }
+
+  test("SPARK-4226 Add support for subqueries in predicates") {
+    checkAnswer(
+      sql(
+        """SELECT a.key FROM testData a
+          |WHERE a.key in
+          |(SELECT b.key FROM testData b WHERE b.key in (1))""".stripMargin), Row(1))
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 98263f602e9e..ecabc5fe2862 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1206,6 +1206,12 @@
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
       Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString)
 
+    /* Subquery expressions in where condition: matches a TOK_SUBQUERY_EXPR node whose operator
+     * is "in". The `exprsn` child is converted with nodeToExpr and the `query` child with
+     * nodeToPlan; the result is an In expression holding a single SubqueryExpression. */
+    case Token("TOK_SUBQUERY_EXPR",
+           Token("TOK_SUBQUERY_OP",
+             Token("in", Nil) :: Nil) ::
+           query :: exprsn :: Nil) =>
+      In(nodeToExpr(exprsn), Seq(SubqueryExpression(nodeToPlan(query))))
     // This code is adapted from
     // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
     case ast: ASTNode if numericAstTypes contains ast.getType =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index f2bc73bf3bdf..3640a680b7da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -416,4 +416,23 @@ class SQLQuerySuite extends QueryTest {
     dropTempTable("data")
     setConf("spark.sql.hive.convertCTAS", originalConf)
   }
+
+  // The subquery filters only on its own column b.key; the answer is checked against the
+  // rows returned by a plain `key in (230)` filter on src.
+  test("SPARK-4226 Add support for subqueries in predicates- Uncorelated queries") {
+    checkAnswer(
+      sql(
+        """SELECT a.key FROM src a
+          |WHERE a.key in
+          |(SELECT b.key FROM src b WHERE b.key in (230))""".stripMargin),
+      sql("SELECT key FROM src WHERE key in (230)").collect().toSeq)
+  }
+
+  // The subquery additionally compares the outer column a.value with b.value; the answer is
+  // checked against the rows returned by a plain `key in (230)` filter on src.
+  test("SPARK-4226 Add support for subqueries in predicates- corelated queries") {
+    checkAnswer(
+      sql(
+        """SELECT a.key FROM src a
+          |WHERE a.key in
+          |(SELECT b.key FROM src b WHERE b.key in (230)and a.value=b.value)""".stripMargin),
+      sql("SELECT key FROM src WHERE key in (230)").collect().toSeq)
+  }
+
 }