[SPARK-24781][SQL] Using a reference from Dataset in Filter/Sort might not work #21745
Changes to `Analyzer.scala`:

@@ -1129,7 +1129,8 @@ class Analyzer(
       case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
       case sa @ Sort(_, _, child: Aggregate) => sa

-      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+      case s @ Sort(order, _, child)
+        if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
         val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
         val ordering = newOrder.map(_.asInstanceOf[SortOrder])
         if (child.output == newChild.output) {
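For context, this is the `Sort` shape the widened guard targets, mirroring the regression test added to `DataFrameSuite` below (a sketch assuming a SparkSession named `spark` with its implicits in scope):

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
// df("id") carries df's expression ID, so the sort key is already resolved,
// yet `id` was dropped by select(df("name")) and only surfaces as missingInput.
df.select(df("name")).orderBy(df("id"))
```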
@@ -1140,7 +1141,7 @@ class Analyzer(
           Project(child.output, newSort)
         }

-      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+      case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
         val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
         if (child.output == newChild.output) {
           f.copy(condition = newCond.head)
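The `Filter` case gets the same widened guard; the corresponding query shape from the new test (sketch, same assumptions as above):

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
// The condition df("id") === 0 is resolved, but `id` is not in the output of
// select(df("name")); the rule now adds it to the projection so the Filter binds.
df.select(df("name")).filter(df("id") === 0)
```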
@@ -1151,10 +1152,17 @@ class Analyzer(
         }
       }

+    /**
+     * This method tries to resolve expressions and find missing attributes recursively. Specially,
+     * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
+     * attributes which are missed from child output. This method tries to find the missing
+     * attributes out and add into the projection.
+     */
     private def resolveExprsAndAddMissingAttrs(
         exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
-      if (exprs.forall(_.resolved)) {
-        // All given expressions are resolved, no need to continue anymore.
+      // Missing attributes can be unresolved attributes or resolved attributes which are not in
+      // the output attributes of the plan.
+      if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
         (exprs, plan)
       } else {
         plan match {
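A rough sketch of the new entry check, with plain string sets standing in for Catalyst's `AttributeSet` (the names and values here are purely illustrative):

```scala
// Hypothetical stand-ins: what the Sort/Filter expressions reference, and what
// the plan being resolved against actually outputs.
val exprReferences = Set("id")    // e.g. df("id") used in orderBy/filter
val planOutput     = Set("name")  // e.g. output of select(df("name"))
val allResolved    = true         // the expressions themselves resolve fine

// The old condition stopped as soon as everything was resolved; the new one
// also keeps recursing when a resolved reference is not produced by the plan.
val keepRecursing = !(allResolved && exprReferences.subsetOf(planOutput))  // true
```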
@@ -1165,15 +1173,19 @@ class Analyzer(
           (newExprs, AnalysisBarrier(newChild))

         case p: Project =>
+          // Resolving expressions against current plan.
           val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+          // Recursively resolving expressions on the child of current plan.
           val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
-          val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+          // If some attributes used by expressions are resolvable only on the rewritten child
+          // plan, we need to add them into original projection.
+          val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
Review comment (Contributor): what if we do not do the […]

Reply (Member, Author): Without this […]
The hunk continues:

           (newExprs, Project(p.projectList ++ missingAttrs, newChild))

         case a @ Aggregate(groupExprs, aggExprs, child) =>
           val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
           val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
-          val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+          val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
           if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
             // All the missing attributes are grouping expressions, valid case.
             (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
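A simplified illustration of the new `missingAttrs` computation, again with ordinary sets instead of `AttributeSet` (the values are hypothetical):

```scala
// What the expressions reference once resolved against the rewritten child,
// what the original Project/Aggregate already outputs, and what the rewritten
// child can actually provide.
val newExprRefs    = Set("name", "id")
val operatorOutput = Set("name")        // p.outputSet or a.outputSet
val newChildOutput = Set("name", "id")  // newChild.outputSet

// Only attributes the operator does not already produce, but the rewritten
// child can supply, get appended to the project list / aggregate expressions.
val missingAttrs = (newExprRefs -- operatorOutput).intersect(newChildOutput)  // Set("id")
```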
@@ -1493,7 +1505,11 @@ class Analyzer(
         // Try resolving the ordering as though it is in the aggregate clause.
         try {
-          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
+          // If a sort order is unresolved, containing references not in aggregate, or containing
+          // `AggregateExpression`, we need to push down it to the underlying aggregate operator.
+          val unresolvedSortOrders = sortOrder.filter { s =>
+            !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
+          }
           val aliasedOrdering =
             unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
           val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
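And the aggregate-then-sort shape the widened filter covers, mirroring the test below that disables `spark.sql.retainGroupColumns` so the grouping column is dropped from the aggregate's output (sketch, same assumptions as above):

```scala
import org.apache.spark.sql.functions.count
import spark.implicits._  // assumes an active SparkSession named `spark`

val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
spark.conf.set("spark.sql.retainGroupColumns", "false")
// The aggregate now outputs only count(name), so df("name") in orderBy is
// resolved yet outside aggregate.outputSet and has to be pushed down into the
// aggregate (as an "aggOrder" alias) before sorting.
df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
```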
Changes to `DataFrameSuite.scala`:

@@ -2387,4 +2387,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
     checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
   }
+
+  test("SPARK-24781: Using a reference from Dataset in Filter/Sort might not work") {
+    val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
+    val filter1 = df.select(df("name")).filter(df("id") === 0)
+    val filter2 = df.select(col("name")).filter(col("id") === 0)
+    checkAnswer(filter1, filter2.collect())
+
+    val sort1 = df.select(df("name")).orderBy(df("id"))
+    val sort2 = df.select(col("name")).orderBy(col("id"))
+    checkAnswer(sort1, sort2.collect())
+
+    withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+      val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
+      val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
+      checkAnswer(aggPlusSort1, aggPlusSort2.collect())
+
+      val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
+      val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
+      checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
+    }
+  }
 }
Review comment: This is the second time, but do we need a fix in the `Aggregate` case? The logic seems completely different. Or can we remove the `Aggregate` case if `ResolveAggregateFunctions` can handle this? I don't think we have any reason to keep wrong logic.

Reply: Thanks. I think it's better to have a reproducible test case before changing the `Aggregate` case. I'm trying to create one; that will give us more confidence to change the `Aggregate` case. Actually, I found another place we need to fix. It seems we don't have enough test coverage for similar features.

Review comment: The logic gets convoluted here, so we need to add comments. Basically, we need to explain when we should expand the project list.