Commit e3201e1 (1 parent: 94641fe)

maryannxue authored and gatorsmile committed
[SPARK-24035][SQL] SQL syntax for Pivot
## What changes were proposed in this pull request?

Add SQL support for Pivot according to the Pivot grammar defined by Oracle (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_clause.htm), with some simplifications based on our existing functionality and limitations for Pivot at the backend:

1. For pivot_for_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm), the column list form is not supported, which means the pivot column can only be a single column.
2. For pivot_in_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_in_clause.htm), the sub-query form and "ANY" are not supported (Oracle only supports these for XML anyway).
3. For pivot_in_clause, aliases for the constant values are not supported.

The code changes are:

1. Add parser support for Pivot. Note that according to https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#i2076542, Pivot cannot be used together with lateral views in the from clause. This restriction has been implemented in the parser rule.
2. Infer group-by expressions: group-by expressions are not explicitly specified in the SQL Pivot clause and need to be deduced based on this rule: https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#CHDFAFIE, so we have to fill them in later, at the query analysis stage.
3. Override Pivot.resolved as "false": for the reason mentioned in [2], and because the output attributes change after Pivot is replaced by a Project or an Aggregate, we avoid resolving parent references until after Pivot has been resolved and replaced.
4. Verify aggregate expressions: only aggregate expressions, with or without aliases, can appear in the first part of the Pivot clause; this check is performed at the analysis stage.

## How was this patch tested?

A new test suite PivotSuite is added.

Author: maryannxue <[email protected]>

Closes #21187 from maryannxue/spark-24035.
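For a quick feel of the feature, here is a minimal sketch of the new syntax and of the group-by inference described in [2]. It assumes a Spark build containing this patch, a SparkSession named `spark`, and the courseSales(course, year, earnings) view defined in the new SQL tests below; the explicit GROUP BY rewrite at the end is only an approximation of what the analyzer produces.

```scala
// A PIVOT query in the new syntax (assumes `spark` and the courseSales view).
spark.sql("""
  SELECT * FROM courseSales
  PIVOT (
    sum(earnings)
    FOR course IN ('dotNET', 'Java')
  )
""").show()

// No group-by columns are written above; the analyzer deduces them as all
// input columns referenced neither by the aggregates nor by the pivot column,
// here just `year`. The query is thus roughly equivalent to:
spark.sql("""
  SELECT year,
         sum(CASE WHEN course = 'dotNET' THEN earnings END) AS `dotNET`,
         sum(CASE WHEN course = 'Java' THEN earnings END) AS `Java`
  FROM courseSales
  GROUP BY year
""").show()
```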
File tree: 8 files changed, +386 -23 lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
Lines changed: 9 additions & 3 deletions

```diff
@@ -398,7 +398,7 @@ hintStatement
     ;
 
 fromClause
-    : FROM relation (',' relation)* lateralView*
+    : FROM relation (',' relation)* (pivotClause | lateralView*)?
     ;
 
 aggregation
@@ -413,6 +413,10 @@ groupingSet
     | expression
     ;
 
+pivotClause
+    : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')'
+    ;
+
 lateralView
     : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
     ;
@@ -725,7 +729,7 @@ nonReserved
     | ADD
     | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER
     | MAP | ARRAY | STRUCT
-    | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
+    | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
     | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
     | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS
     | GROUPING | CUBE | ROLLUP
@@ -745,7 +749,7 @@ nonReserved
     | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
     | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH
     | ASC | DESC | LIMIT | RENAME | SETS
-    | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE
+    | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE
     | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS | LIKE
     | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE
     | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN
@@ -760,6 +764,7 @@ FROM: 'FROM';
 ADD: 'ADD';
 AS: 'AS';
 ALL: 'ALL';
+ANY: 'ANY';
 DISTINCT: 'DISTINCT';
 WHERE: 'WHERE';
 GROUP: 'GROUP';
@@ -805,6 +810,7 @@ RIGHT: 'RIGHT';
 FULL: 'FULL';
 NATURAL: 'NATURAL';
 ON: 'ON';
+PIVOT: 'PIVOT';
 LATERAL: 'LATERAL';
 WINDOW: 'WINDOW';
 OVER: 'OVER';
```
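The pivotClause rule takes a namedExpressionSeq before FOR, so each aggregate may carry an alias, and the IN list may hold several constants; the revised fromClause also makes pivotClause and lateral views mutually exclusive, implementing the Oracle restriction mentioned above. A sketch exercising the full shape the grammar accepts (same assumptions as the example above):

```scala
// Multiple aliased aggregates and multiple pivot values in one clause
// (assumes `spark` and the courseSales view from the tests below).
spark.sql("""
  SELECT * FROM courseSales
  PIVOT (
    sum(earnings) AS s, avg(earnings) AS a
    FOR year IN (2012, 2013)
  )
""").show()
```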

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
Lines changed: 25 additions & 10 deletions

```diff
@@ -275,9 +275,9 @@ class Analyzer(
     case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
       g.copy(aggregations = assignAliases(g.aggregations))
 
-    case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
-      if child.resolved && hasUnresolvedAlias(groupByExprs) =>
-      Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+    case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
+      if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
+      Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)
 
     case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
       Project(assignAliases(projectList), child)
@@ -504,9 +504,20 @@ class Analyzer(
 
   object ResolvePivot extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
-        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
-      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+      case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
+        || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
+        || !p.pivotColumn.resolved => p
+      case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
+        // Check all aggregate expressions.
+        aggregates.foreach { e =>
+          if (!isAggregateExpression(e)) {
+            throw new AnalysisException(
+              s"Aggregate expression required for pivot, found '$e'")
+          }
+        }
+        // Group-by expressions coming from SQL are implicit and need to be deduced.
+        val groupByExprs = groupByExprsOpt.getOrElse(
+          (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
         val singleAgg = aggregates.size == 1
         def outputName(value: Literal, aggregate: Expression): String = {
           val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
@@ -568,16 +579,20 @@
             // TODO: Don't construct the physical container until after analysis.
             case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
           }
-          if (filteredAggregate.fastEquals(aggregate)) {
-            throw new AnalysisException(
-              s"Aggregate expression required for pivot, found '$aggregate'")
-          }
           Alias(filteredAggregate, outputName(value, aggregate))()
         }
       }
       Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
     }
   }
+
+  private def isAggregateExpression(expr: Expression): Boolean = {
+    expr match {
+      case Alias(e, _) => isAggregateExpression(e)
+      case AggregateExpression(_, _, _, _) => true
+      case _ => false
+    }
+  }
 }
 
 /**
```
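With these changes, ResolvePivot rejects non-aggregate expressions up front instead of during per-value rewriting, and fills in the implicit group-by columns when the Pivot came from SQL. A sketch of the error path (assumes `spark` and the courseSales view; the message text comes from the AnalysisException above):

```scala
import org.apache.spark.sql.AnalysisException

// abs(earnings) is not an aggregate, so analysis fails as soon as the
// query is analyzed (spark.sql analyzes eagerly).
try {
  spark.sql(
    "SELECT * FROM courseSales PIVOT (abs(earnings) FOR year IN (2012, 2013))")
} catch {
  case e: AnalysisException =>
    // e.g. "Aggregate expression required for pivot, found 'abs(...)'"
    println(e.getMessage)
}
```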

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
Lines changed: 19 additions & 1 deletion

```diff
@@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       val join = right.optionalMap(left)(Join(_, _, Inner, None))
       withJoinRelations(join, relation)
     }
-    ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+    if (ctx.pivotClause() != null) {
+      withPivot(ctx.pivotClause, from)
+    } else {
+      ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+    }
   }
 
   /**
@@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     plan
   }
 
+  /**
+   * Add a [[Pivot]] to a logical plan.
+   */
+  private def withPivot(
+      ctx: PivotClauseContext,
+      query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+    val aggregates = Option(ctx.aggregates).toSeq
+      .flatMap(_.namedExpression.asScala)
+      .map(typedVisit[Expression])
+    val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText)
+    val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply)
+    Pivot(None, pivotColumn, pivotValues, aggregates, query)
+  }
+
   /**
    * Add a [[Generate]] (Lateral View) to a logical plan.
    */
```
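A sketch of what the new parser path yields: parsing a PIVOT query with the Catalyst parser produces an unresolved Pivot node whose groupByExprsOpt is None, which ResolvePivot later fills in (the comment on the printed tree is illustrative, not exact output):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// withPivot builds Pivot(None, pivotColumn, pivotValues, aggregates, query);
// the implicit group-by columns are deduced later by ResolvePivot.
val plan = CatalystSqlParser.parsePlan(
  "SELECT * FROM courseSales PIVOT (sum(earnings) FOR year IN (2012, 2013))")
println(plan.treeString) // an unresolved Project over a Pivot with groupByExprsOpt = None
```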

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
Lines changed: 22 additions & 5 deletions

```diff
@@ -686,17 +686,34 @@ case class GroupingSets(
   override lazy val resolved: Boolean = false
 }
 
+/**
+ * A constructor for creating a pivot, which will later be converted to a [[Project]]
+ * or an [[Aggregate]] during the query analysis.
+ *
+ * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming
+ *                        from SQL, in which group by expressions are not explicitly specified.
+ * @param pivotColumn The pivot column.
+ * @param pivotValues A sequence of values for the pivot column.
+ * @param aggregates The aggregation expressions, each with or without an alias.
+ * @param child Child operator
+ */
 case class Pivot(
-    groupByExprs: Seq[NamedExpression],
+    groupByExprsOpt: Option[Seq[NamedExpression]],
     pivotColumn: Expression,
     pivotValues: Seq[Literal],
     aggregates: Seq[Expression],
     child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
-    case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
-    case _ => pivotValues.flatMap{ value =>
-      aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
+  override lazy val resolved = false // Pivot will be replaced after being resolved.
+  override def output: Seq[Attribute] = {
+    val pivotAgg = aggregates match {
+      case agg :: Nil =>
+        pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+      case _ =>
+        pivotValues.flatMap { value =>
+          aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
+        }
     }
+    groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
   }
 }
 
```
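The output naming convention this encodes is observable on the result schema once ResolvePivot has replaced the node: with a single aggregate the pivoted columns are named after the pivot values, while with multiple aggregates each value/aggregate pair gets its own column. A sketch (assumes `spark` and the courseSales view; the aliased variant is exercised by pivot.sql below):

```scala
// Single aggregate: output columns are named after the pivot values.
spark.sql("""
  SELECT * FROM courseSales PIVOT (sum(earnings) FOR year IN (2012, 2013))
""").columns
// e.g. Array(course, 2012, 2013)

// Multiple aliased aggregates: one "<value>_<alias>" column per pair.
spark.sql("""
  SELECT * FROM courseSales
  PIVOT (sum(earnings) s, avg(earnings) a FOR year IN (2012, 2013))
""").columns
// includes 2012_s, 2013_s, 2012_a and 2013_a
```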

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
Lines changed: 3 additions & 3 deletions

```diff
@@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite {
     "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables",
     "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive",
     "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp",
-    "view", "while", "year", "work", "transaction", "write", "isolation", "level",
-    "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint",
+    "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot",
+    "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint",
     "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp",
     "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external",
     "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in",
-    "insert", "int", "into", "is", "lateral", "like", "local", "none", "null",
+    "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null",
     "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke",
     "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
     "true", "truncate", "update", "user", "values", "with", "regexp", "rlike",
```

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
Lines changed: 1 addition & 1 deletion

```diff
@@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql](
       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         Dataset.ofRows(
-          df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+          df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
     }
   }
 
```
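Callers of the DataFrame pivot API are unaffected: the method signatures are unchanged, and the explicit grouping expressions are now simply wrapped in Some(...) when the Pivot node is built. A sketch (assumes `spark` and the courseSales view):

```scala
import org.apache.spark.sql.functions.sum

// The pre-existing Dataset API path, equivalent to the SQL PIVOT above.
spark.table("courseSales")
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum("earnings"))
  .show()
```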

sql/core/src/test/resources/sql-tests/inputs/pivot.sql
Lines changed: 113 additions & 0 deletions

```sql
create temporary view courseSales as select * from values
  ("dotNET", 2012, 10000),
  ("Java", 2012, 20000),
  ("dotNET", 2012, 5000),
  ("dotNET", 2013, 48000),
  ("Java", 2013, 30000)
  as courseSales(course, year, earnings);

create temporary view years as select * from values
  (2012, 1),
  (2013, 2)
  as years(y, s);

-- pivot courses
SELECT * FROM (
  SELECT year, course, earnings FROM courseSales
)
PIVOT (
  sum(earnings)
  FOR course IN ('dotNET', 'Java')
);

-- pivot years with no subquery
SELECT * FROM courseSales
PIVOT (
  sum(earnings)
  FOR year IN (2012, 2013)
);

-- pivot courses with multiple aggregations
SELECT * FROM (
  SELECT year, course, earnings FROM courseSales
)
PIVOT (
  sum(earnings), avg(earnings)
  FOR course IN ('dotNET', 'Java')
);

-- pivot with no group by column
SELECT * FROM (
  SELECT course, earnings FROM courseSales
)
PIVOT (
  sum(earnings)
  FOR course IN ('dotNET', 'Java')
);

-- pivot with no group by column and with multiple aggregations on different columns
SELECT * FROM (
  SELECT year, course, earnings FROM courseSales
)
PIVOT (
  sum(earnings), min(year)
  FOR course IN ('dotNET', 'Java')
);

-- pivot on join query with multiple group by columns
SELECT * FROM (
  SELECT course, year, earnings, s
  FROM courseSales
  JOIN years ON year = y
)
PIVOT (
  sum(earnings)
  FOR s IN (1, 2)
);

-- pivot on join query with multiple aggregations on different columns
SELECT * FROM (
  SELECT course, year, earnings, s
  FROM courseSales
  JOIN years ON year = y
)
PIVOT (
  sum(earnings), min(s)
  FOR course IN ('dotNET', 'Java')
);

-- pivot on join query with multiple columns in one aggregation
SELECT * FROM (
  SELECT course, year, earnings, s
  FROM courseSales
  JOIN years ON year = y
)
PIVOT (
  sum(earnings * s)
  FOR course IN ('dotNET', 'Java')
);

-- pivot with aliases and projection
SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM (
  SELECT year y, course c, earnings e FROM courseSales
)
PIVOT (
  sum(e) s, avg(e) a
  FOR y IN (2012, 2013)
);

-- pivot years with non-aggregate function
SELECT * FROM courseSales
PIVOT (
  abs(earnings)
  FOR year IN (2012, 2013)
);

-- pivot with unresolvable columns
SELECT * FROM (
  SELECT course, earnings FROM courseSales
)
PIVOT (
  sum(earnings)
  FOR year IN (2012, 2013)
);
```
