Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -646,46 +646,72 @@ object FoldablePropagation extends Rule[LogicalPlan] {
}
case _ => Nil
})
val replaceFoldable: PartialFunction[Expression, Expression] = {
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
}

if (foldableMap.isEmpty) {
plan
} else {
var stop = false
CleanupAliases(plan.transformUp {
case u: Union =>
stop = true
u
case c: Command =>
stop = true
c
// For outer join, although its output attributes are derived from its children, they are
// actually different attributes: the output of outer join is not always picked from its
// children, but can also be null.
// A leaf node should not stop the folding process (note that we are traversing up the
// tree, starting at the leaf nodes); so we are allowing it.
case l: LeafNode =>
l

// We can only propagate foldables for a subset of unary nodes.
case u: UnaryNode if !stop && canPropagateFoldables(u) =>
u.transformExpressions(replaceFoldable)

// Allow inner joins. We do not allow outer join, although its output attributes are
// derived from its children, they are actually different attributes: the output of outer
// join is not always picked from its children, but can also be null.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join.
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
case j @ Join(_, _, Inner, _) =>
j.transformExpressions(replaceFoldable)

// We can fold the projections an expand holds. However expand changes the output columns
// and often reuses the underlying attributes; so we cannot assume that a column is still
// foldable after the expand has been applied.
// TODO(hvanhovell): Expand should use new attributes as the output attributes.
case expand: Expand if !stop =>
val newExpand = expand.copy(projections = expand.projections.map { projection =>
projection.map(_.transform(replaceFoldable))
})
stop = true
j
newExpand

// These 3 operators take attributes as constructor parameters, and these attributes
// can't be replaced by alias.
case m: MapGroups =>
stop = true
m
case f: FlatMapGroupsInR =>
stop = true
f
case c: CoGroup =>
case other =>
stop = true
c

case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
}
other
})
}
}

/**
 * Decides whether foldable propagation is permitted through the given [[UnaryNode]].
 *
 * Only the operator types enumerated in the match below are accepted; every other unary node
 * yields `false`.
 *
 * @param u the unary operator to check
 * @return `true` if `u` is one of the accepted operator types, `false` otherwise
 */
private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
  case _: Project | _: Filter | _: SubqueryAlias | _: Aggregate => true
  case _: Window | _: Sample | _: GlobalLimit | _: LocalLimit => true
  case _: Generate | _: Distinct | _: AppendColumns | _: AppendColumnsWithObject => true
  case _: BroadcastHint | _: RedistributeData | _: Repartition | _: Sort => true
  case _ => false
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,36 @@ class FoldablePropagationSuite extends PlanTest {
// NOTE(review): this span comes from a rendered diff. Two forms of each Union child list
// appear back to back (presumably the removed lines first, then the added ones), so the
// span does not compile as shown; only one form of each pair belongs in the real file.
test("Propagate in subqueries of Union queries") {
val query = Union(
Seq(
// Form 1 of the children: each child projects only 'x + 'a.
testRelation.select(Literal(1).as('x), 'a).select('x + 'a),
testRelation.select(Literal(2).as('x), 'a).select('x + 'a)))
// Form 2 of the children: each child projects both 'x and 'x + 'a
// (presumably so that the select('x) applied on top of the Union has an 'x to pick).
testRelation.select(Literal(1).as('x), 'a).select('x, 'x + 'a),
testRelation.select(Literal(2).as('x), 'a).select('x, 'x + 'a)))
.select('x)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = Union(
Seq(
// Expected plan for form 1: inside each child, 'x in the sum is replaced by its alias
// Literal(n).as('x).
testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")),
testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)"))))
// Expected plan for form 2: the alias replaces 'x in both projected columns of each child,
// while the select('x) on top of the Union is left as written.
testRelation.select(Literal(1).as('x), 'a)
.select(Literal(1).as('x), (Literal(1).as('x) + 'a).as("(x + a)")),
testRelation.select(Literal(2).as('x), 'a)
.select(Literal(2).as('x), (Literal(2).as('x) + 'a).as("(x + a)"))))
.select('x).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate in expand") {
  // Two literal aliases projected by the child of the Expand.
  val litA = Literal(1).as('a)
  val litB = Literal(2).as('b)
  // Nullable versions of the alias attributes, used for the Expand and the operators above it.
  val attrA = litA.toAttribute.withNullability(true)
  val attrB = litB.toAttribute.withNullability(true)
  val expand = Expand(
    Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
    Seq(attrA, attrB),
    OneRowRelation.select(litA, litB))

  val optimized = Optimize.execute(expand.where(attrA.isNotNull).select(attrA, attrB).analyze)

  // Expected plan: the aliases are substituted inside the Expand's projections only; the
  // filter and the projection on top of the Expand still reference the attributes.
  val expectedExpand = expand.copy(projections = Seq(
    Seq(Literal(null), litB),
    Seq(litA, Literal(null))))
  val expected = expectedExpand.where(attrA.isNotNull).select(attrA, attrB).analyze
  comparePlans(optimized, expected)
}
}
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1;
-- Aggregate with nulls.
SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
FROM testData;

-- Aggregate with foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 14
-- Number of queries: 15


-- !query 0
Expand Down Expand Up @@ -131,3 +131,11 @@ FROM testData
struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
-- !query 13 output
-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7


-- !query 14
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a
-- !query 14 schema
struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint>
-- !query 14 output
1 1