Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
04959c2
refactor analyzer adding a new object
anchovYu Nov 23, 2022
6f44c85
lca code
anchovYu Nov 23, 2022
725e5ac
add tests, refine logic
anchovYu Nov 28, 2022
660e1d2
move lca rule to a new file
anchovYu Nov 28, 2022
fd06094
rename conf
anchovYu Nov 28, 2022
7d4f80f
test failure
anchovYu Nov 29, 2022
b9704d5
small fix
anchovYu Nov 29, 2022
777f13a
temp commit, still in implementation
anchovYu Nov 29, 2022
09480ea
a temporary solution, but still fail certain cases
anchovYu Nov 30, 2022
c972738
working solution, needs some refinement
anchovYu Dec 1, 2022
97ee293
Merge remote-tracking branch 'apache/master' into SPARK-27561-refactor
anchovYu Dec 1, 2022
5785943
make changes to accomodate the recent refactor
anchovYu Dec 2, 2022
757cffb
introduce leaf exp in Project as well
anchovYu Dec 5, 2022
29de892
handle a corner case
anchovYu Dec 5, 2022
72991c6
add more tests; add check rule
anchovYu Dec 6, 2022
d45fe31
uplift the necessity to resolve expression in second phase; add more …
anchovYu Dec 8, 2022
1f55f73
address comments to add tests for LCA off
anchovYu Dec 8, 2022
f753529
revert the refactor, split LCA into two rules
anchovYu Dec 9, 2022
b9f706f
better refactor
anchovYu Dec 9, 2022
94d5c9e
address comments
anchovYu Dec 9, 2022
d2e75fd
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 9, 2022
edde37c
basic version passing all tests
anchovYu Dec 9, 2022
fb7b18c
update the logic, add and refactor tests
anchovYu Dec 12, 2022
3698cff
update comments
anchovYu Dec 13, 2022
e700d6a
add a corner case comment
anchovYu Dec 13, 2022
8d20986
address comments
anchovYu Dec 13, 2022
d952aa7
Merge branch 'SPARK-27561-refactor' into SPARK-27561-agg
anchovYu Dec 13, 2022
44d5a3d
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 13, 2022
ccebc1c
revert some changes
anchovYu Dec 13, 2022
5540b70
fix few todos
anchovYu Dec 13, 2022
338ba11
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 16, 2022
136a930
fix the failing test
anchovYu Dec 16, 2022
5076ad2
fix the missing_aggregate issue, turn on conf to see failed tests
anchovYu Dec 19, 2022
2f2dee5
remove few todos
anchovYu Dec 19, 2022
3a5509a
better fix to maintain aggregate error: only lift up in certain cases
anchovYu Dec 20, 2022
a23debb
Merge remote-tracking branch 'apache/master' into SPARK-27561-agg
anchovYu Dec 20, 2022
b200da0
typo
anchovYu Dec 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
better fix to maintain aggregate error: only lift up in certain cases
  • Loading branch information
anchovYu committed Dec 20, 2022
commit 3a5509aa56218be561eb391cf116f1e6c406f560
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression, OuterReference, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, LeafExpression, Literal, NamedExpression, OuterReference, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -307,6 +307,27 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved
&& aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>

// Check if current Aggregate is eligible to lift up with Project: the aggregate
// expression only contains: 1) aggregate functions, 2) grouping expressions, 3) lateral
// column alias reference or 4) literals.
// This check is to prevent unnecessary transformation on invalid plan, to guarantee it
// throws the same exception. For example, cases like non-aggregate expressions not
// in group by, once transformed, will throw a different exception: missing input.
def eligibleToLiftUp(exp: Expression): Boolean = {
exp match {
case e: AggregateExpression if AggregateExpression.isAggregate(e) => true
case e if groupingExpressions.exists(_.semanticEquals(e)) => true
case _: Literal | _: LateralColumnAliasReference => true
case s: ScalarSubquery if s.children.nonEmpty
&& !groupingExpressions.exists(_.semanticEquals(s)) => false
case _: LeafExpression => false
case e => e.children.forall(eligibleToLiftUp)
}
}
if (!aggregateExpressions.forall(eligibleToLiftUp)) {
return agg
}

val newAggExprs = collection.mutable.Set.empty[NamedExpression]
val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression]
val projectExprs = aggregateExpressions.map { exp =>
Expand All @@ -332,10 +353,6 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
if (newAggExprs.isEmpty) {
agg
} else {
// perform an early check on current Aggregate before any lift-up / push-down to throw
// the same exception such as non-aggregate expressions not in group by, which becomes
// missing input after transformation
earlyCheckAggregate(agg)
Project(
projectList = projectExprs,
child = agg.copy(aggregateExpressions = newAggExprs.toSeq)
Expand All @@ -344,26 +361,4 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
}
}
}

private def earlyCheckAggregate(plan: Aggregate): Unit = {
val Aggregate(groupingExprs, aggregateExprs, _) = plan
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case expr: Expression if AggregateExpression.isAggregate(expr) =>
// doesn't perform any check on aggregation functions
case _: Attribute if groupingExprs.isEmpty =>
plan.failAnalysis(
errorClass = "MISSING_GROUP_BY",
messageParameters = Map.empty)
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
case s: ScalarSubquery
if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) =>
s.failAnalysis(
errorClass = "_LEGACY_ERROR_TEMP_2423",
messageParameters = Map("sqlExpr" -> s.sql))
case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
aggregateExprs.foreach(checkValidAggregateExpression)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -714,37 +714,57 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase {
}
}

test("Non-aggregating expressions not in group by still throws the same error") {
// query without lateral alias
assert(
intercept[AnalysisException] {
sql(s"SELECT dept AS a, salary FROM $testTable GROUP BY dept")
}.getErrorClass == "MISSING_AGGREGATION")

assert(
intercept[AnalysisException] {
sql(s"SELECT avg(salary), avg(avg(salary)) FROM $testTable GROUP BY dept")
}.getErrorClass == "NESTED_AGGREGATE_FUNCTION")

// query with lateral alias throws the same error
assert(
intercept[AnalysisException] {
sql(s"SELECT dept AS a, a, salary FROM $testTable GROUP BY dept")
}.getErrorClass == "MISSING_AGGREGATION")
// no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE
assert(
intercept[AnalysisException] {
sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept")
}.getErrorClass == "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC")

// checkAnalysis doesn't canonicalize expressions when performing check of non-aggregation
// expression in group by. With LCA, it doesn't change and throw same exception
val e1 = intercept[AnalysisException] {
sql(s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10")
}
val e2 = intercept[AnalysisException] {
sql(s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10")
test("Aggregate expressions not eligible to lift up, throws same error as inline") {
def checkSameMissingAggregationError(q1: String, q2: String, expressionParam: String): Unit = {
Seq(q1, q2).foreach { query =>
val e = intercept[AnalysisException] { sql(query) }
assert(e.getErrorClass == "MISSING_AGGREGATION")
assert(e.messageParameters.get("expression").exists(_ == expressionParam))
}
}
assert(e1.getErrorClass == e2.getErrorClass)

val suffix = s"FROM $testTable GROUP BY dept"
checkSameMissingAggregationError(
s"SELECT dept AS a, dept, salary $suffix",
s"SELECT dept AS a, a, salary $suffix",
"\"salary\"")
checkSameMissingAggregationError(
s"SELECT dept AS a, dept + salary $suffix",
s"SELECT dept AS a, a + salary $suffix",
"\"salary\"")
checkSameMissingAggregationError(
s"SELECT avg(salary) AS a, avg(salary) + bonus $suffix",
s"SELECT avg(salary) AS a, a + bonus $suffix",
"\"bonus\"")
checkSameMissingAggregationError(
s"SELECT dept AS a, dept, avg(salary) + bonus + 10 $suffix",
s"SELECT dept AS a, a, avg(salary) + bonus + 10 $suffix",
"\"bonus\"")
checkSameMissingAggregationError(
s"SELECT avg(salary) AS a, avg(salary), dept FROM $testTable GROUP BY dept + 10",
s"SELECT avg(salary) AS a, a, dept FROM $testTable GROUP BY dept + 10",
"\"dept\"")
checkSameMissingAggregationError(
s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10",
s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10",
"\"dept\"")
Seq(
s"SELECT dept AS a, dept, " +
s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix",
s"SELECT dept AS a, a, " +
s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix"
).foreach { query =>
val e = intercept[AnalysisException] { sql(query) }
assert(e.getErrorClass == "_LEGACY_ERROR_TEMP_2423") }

// one exception: no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE
checkError(
exception = intercept[AnalysisException] {
sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept")
},
errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC",
sqlState = "0A000",
parameters = Map("lca" -> "`a`", "aggFunc" -> "\"avg(lateralAliasReference(a))\"")
)
}
}