diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93e69d409cb9..e0db92c2a47e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -115,6 +115,7 @@ object UnionPushdown extends Rule[LogicalPlan] {
  *  - Aggregate
  *  - Project <- Join
  *  - LeftSemiJoin
+ *  - Generate
  *  - Collapse adjacent projections, performing alias substitution.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
@@ -171,6 +172,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
       Project(substitutedProjection, child)
 
+    // Push a Project beneath a Generate that blocks column pruning
+    case p @ pushBelowGenerate(newChild) =>
+      p.copy(child = newChild)
+
     // Eliminate no-op Projects
     case Project(projectList, child) if child.output == projectList => child
   }
@@ -182,6 +187,32 @@ object ColumnPruning extends Rule[LogicalPlan] {
     } else {
       c
     }
+
+  object pushBelowGenerate {
+    // Generate blocks Project pushdown, so insert a Project holding all the
+    // collected references below the Generate
+    def collectRefersUntilGen(refers: AttributeSet, plan: LogicalPlan): LogicalPlan = {
+      val collectRefers = refers ++ plan.references
+      plan match {
+        case filter @ Filter(_, c) =>
+          val newChild = collectRefersUntilGen(collectRefers, c)
+          // a null result indicates the child is unchanged
+          if (newChild != null) filter.copy(child = newChild) else null
+        case gen @ Generate(_, _, _, _, c) =>
+          if ((c.outputSet -- collectRefers.filter(c.outputSet.contains)).nonEmpty) {
+            gen.copy(child = Project(collectRefers.filter(c.outputSet.contains).toSeq, c))
+          } else {
+            null
+          }
+        case _ => null
+      }
+    }
+
+    def unapply(plan: Project): Option[LogicalPlan] = {
+      val newChild = collectRefersUntilGen(plan.references, plan.child)
+      if (newChild != null) Some(newChild) else None
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 8633e06093cf..88afb965e863 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -55,6 +55,7 @@ case class Generate(
     child: LogicalPlan)
   extends UnaryNode {
 
+
   protected def generatorOutput: Seq[Attribute] = {
     val output = alias
       .map(a => generator.output.map(_.withQualifiers(a :: Nil)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 55c6766520a1..70e533fa5b9c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -502,7 +502,11 @@ class FilterPushdownSuite extends PlanTest {
         .where(('c > 6) || ('b > 5)).analyze
     }
     val optimized = Optimize(originalQuery)
-
-    comparePlans(optimized, originalQuery)
+    val correctAnswer = {
+      testRelationWithArrayType
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .where(('c > 6) || ('b > 5)).analyze
+    }
+    comparePlans(optimized, correctAnswer)
   }
 }
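The suite change above pins down the filter case. As a quick illustration of the rewrite itself, here is a minimal sketch in the style of FilterPushdownSuite's Catalyst DSL. It assumes the suite's existing `testRelationWithArrayType` fixture with columns `'a`, `'b`, and `'c_arr`; the before/after plan shapes are illustrative, not taken from the patch:

```scala
// Sketch: what pushBelowGenerate does to a Project sitting above a Generate.
// Before optimization, 'b survives the scan even though nothing references it.
val originalQuery =
  testRelationWithArrayType
    .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
    .select('a, 'c)
    .analyze

// After ColumnPruning, a Project keeping only the referenced columns
// ('a plus the generator input 'c_arr) is inserted below the Generate.
val expectedAnswer =
  testRelationWithArrayType
    .select('a, 'c_arr)
    .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
    .select('a, 'c)
    .analyze

comparePlans(Optimize(originalQuery), expectedAnswer)
```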
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 12271048bb39..4bb2d14f63db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -56,7 +56,8 @@ case class Generate(
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
   override def execute(): RDD[Row] = {
-    if (join) {
+    // SPARK-6489: do not join when the child has no output
+    if (join && child.output.nonEmpty) {
       child.execute().mapPartitions { iter =>
         val nullValues = Seq.fill(generator.output.size)(Literal(null))
         // Used to produce rows with no matches when outer = true.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a3497eadd67f..747dcc0762cb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -367,7 +367,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
         INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1)
         SELECT title, air_date, doctor FROM episodes
         """.cmd
-    )
+    ),
+    TestTable("person",
+      ("CREATE TABLE person(name string, age int, data array<int>) " +
+        "ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' " +
+        "COLLECTION ITEMS TERMINATED BY ':'").cmd,
+      s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/person.txt")}' INTO TABLE person".cmd
+    )
   )
 
   hiveQTestUtilTables.foreach(registerTestTable)
diff --git a/sql/hive/src/test/resources/data/files/person.txt b/sql/hive/src/test/resources/data/files/person.txt
new file mode 100644
index 000000000000..7241707e4b63
--- /dev/null
+++ b/sql/hive/src/test/resources/data/files/person.txt
@@ -0,0 +1,5 @@
+A, 20, 10:12:19
+B, 25, 7:8:4
+C, 19, 12:4:232
+D, 73, 243:53:7835
+E, 88, 1345:23:532532:353
\ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 8474d850c9c6..b142da9dd6a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -87,6 +87,25 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
     Seq("key"),
     Seq.empty)
 
+  createPruningTest("Column pruning - explode with aggregate",
+    "SELECT name, sum(d) AS sumd FROM person LATERAL VIEW explode(data) d AS d GROUP BY name",
+    Seq("name", "sumd"),
+    Seq("name", "data"),
+    Seq.empty)
+
+  createPruningTest("Column pruning - outer explode with limit",
+    "SELECT name FROM person LATERAL VIEW OUTER explode(data) outd AS d" +
+      " WHERE name < \"C\" LIMIT 3",
+    Seq("name"),
+    Seq("name", "data"),
+    Seq.empty)
+
+  createPruningTest("Column pruning - select all without explode optimization - query test",
+    "SELECT * FROM person LATERAL VIEW OUTER explode(data) outd AS d WHERE 20 < age",
+    Seq("name", "age", "data", "d"),
+    Seq("name", "age", "data"),
+    Seq.empty)
+
   // Partition pruning tests
 
   createPruningTest("Partition pruning - non-partitioned, non-trivial project",
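Taken together, the PruningSuite cases assert that only the referenced columns (`name` and `data`) reach the Hive scan. Below is a hypothetical spark-shell check against the new `person` table, assuming an in-scope Spark 1.3-era `hiveContext`; the printed plan shape is schematic, not captured output:

```scala
// After this patch, the optimized plan for the aggregate query should carry a
// Project below the Generate that keeps only `name` and `data`, pruning `age`.
val df = hiveContext.sql(
  "SELECT name, sum(d) AS sumd FROM person LATERAL VIEW explode(data) d AS d GROUP BY name")
println(df.queryExecution.optimizedPlan)
// Expected to contain, schematically:
//   Aggregate [name], [name, SUM(d) AS sumd]
//    Generate explode(data), true, false, Some(d)
//     Project [name, data]          <- inserted by pushBelowGenerate
//      MetastoreRelation default, person
```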