Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
SPARK-24696 ColumnPruning rule fails to remove extra Project
  • Loading branch information
maryannxue committed Jun 29, 2018
commit 11fde8ba4b64416d863a69c5587c0db67ea61d0a
Original file line number Diff line number Diff line change
Expand Up @@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] {

/**
* The Project before Filter is not necessary but conflict with PushPredicatesThroughProject,
* so remove it.
* so remove it. Since the Projects have been added top-down, we need to remove in bottom-up
* order, otherwise lower Projects can be missed.
*/
private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
if p2.outputSet.subsetOf(child.outputSet) =>
p1.copy(child = f.copy(child = child))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
Expand Down Expand Up @@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized2, expected2.analyze)
}

test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
val input = LocalRelation('key.int, 'value.string)
val query = input.select('key).where('key > 1).where('key < 10).analyze
val optimized = Optimize.execute(query)
val expected = input.where('key > 1).where('key < 10).select('key).analyze
comparePlans(optimized, expected)
}

// todo: add more tests for column pruning
}
21 changes: 21 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
Copy link
Member

@viirya viirya Jun 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test in Jira is simpler than this. Do we need to have two tables and a join? Why not just use the test in Jira?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new unit test in ColumnPruningSuite.scala already covers that.

withTable("fact_stats", "dim_stats") {
val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4))
val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US"))
spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic)
factData.toDF("date_id", "store_id", "product_id", "units_sold")
.write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats")
storeData.toDF("store_id", "state_province", "country")
.write.mode("overwrite").format("parquet").saveAsTable("dim_stats")
val df = sql(
"""
|SELECT f.date_id, f.product_id, f.store_id FROM
|(SELECT date_id, product_id, store_id
| FROM fact_stats WHERE filterND(date_id)) AS f
|JOIN dim_stats s
|ON f.store_id = s.store_id WHERE s.country = 'DE'
""".stripMargin)
checkAnswer(df, Seq(Row(3, 99, 1)))
}
}
}