Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ hintStatement
;

fromClause
: FROM relation (',' relation)* lateralView*
: FROM relation (',' relation)* (pivotClause | lateralView*)?
;

aggregation
Expand All @@ -413,6 +413,10 @@ groupingSet
| expression
;

pivotClause
: PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')'
;

lateralView
: LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
;
Expand Down Expand Up @@ -725,7 +729,7 @@ nonReserved
| ADD
| OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER
| MAP | ARRAY | STRUCT
| LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
| PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
| DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
| EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS
| GROUPING | CUBE | ROLLUP
Expand All @@ -745,7 +749,7 @@ nonReserved
| REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
| ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH
| ASC | DESC | LIMIT | RENAME | SETS
| AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE
| AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE
| DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE
| NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE
| AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN
Expand All @@ -760,6 +764,7 @@ FROM: 'FROM';
ADD: 'ADD';
AS: 'AS';
ALL: 'ALL';
ANY: 'ANY';
DISTINCT: 'DISTINCT';
WHERE: 'WHERE';
GROUP: 'GROUP';
Expand Down Expand Up @@ -805,6 +810,7 @@ RIGHT: 'RIGHT';
FULL: 'FULL';
NATURAL: 'NATURAL';
ON: 'ON';
PIVOT: 'PIVOT';
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add the keywords you added here to nonReserved (line 723)? Also update the suite TableIdentifierParserSuite?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll update TableIndentifierParserSuite. I believe I've added them to nonReserved already. Did I miss something?

LATERAL: 'LATERAL';
WINDOW: 'WINDOW';
OVER: 'OVER';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ class Analyzer(
case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.copy(aggregations = assignAliases(g.aggregations))

case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
if child.resolved && hasUnresolvedAlias(groupByExprs) =>
Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)

case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
Expand Down Expand Up @@ -504,9 +504,20 @@ class Analyzer(

object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
| !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
|| (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
|| !p.pivotColumn.resolved => p
case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
// Check all aggregate expressions.
aggregates.foreach { e =>
if (!isAggregateExpression(e)) {
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$e'")
}
}
// Group-by expressions coming from SQL are implicit and need to be deduced.
val groupByExprs = groupByExprsOpt.getOrElse(
(child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
val singleAgg = aggregates.size == 1
def outputName(value: Literal, aggregate: Expression): String = {
val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
Expand Down Expand Up @@ -568,16 +579,20 @@ class Analyzer(
// TODO: Don't construct the physical container until after analysis.
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
Alias(filteredAggregate, outputName(value, aggregate))()
}
}
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
}
}

private def isAggregateExpression(expr: Expression): Boolean = {
expr match {
case Alias(e, _) => isAggregateExpression(e)
case AggregateExpression(_, _, _, _) => true
case _ => false
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val join = right.optionalMap(left)(Join(_, _, Inner, None))
withJoinRelations(join, relation)
}
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
if (ctx.pivotClause() != null) {
withPivot(ctx.pivotClause, from)
} else {
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}
}

/**
Expand Down Expand Up @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
plan
}

/**
* Add a [[Pivot]] to a logical plan.
*/
private def withPivot(
ctx: PivotClauseContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val aggregates = Option(ctx.aggregates).toSeq
.flatMap(_.namedExpression.asScala)
.map(typedVisit[Expression])
val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText)
val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply)
Pivot(None, pivotColumn, pivotValues, aggregates, query)
}

/**
* Add a [[Generate]] (Lateral View) to a logical plan.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,17 +686,34 @@ case class GroupingSets(
override lazy val resolved: Boolean = false
}

/**
* A constructor for creating a pivot, which will later be converted to a [[Project]]
* or an [[Aggregate]] during the query analysis.
*
* @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming
* from SQL, in which group by expressions are not explicitly specified.
* @param pivotColumn The pivot column.
* @param pivotValues A sequence of values for the pivot column.
* @param aggregates The aggregation expressions, each with or without an alias.
* @param child Child operator
*/
case class Pivot(
groupByExprs: Seq[NamedExpression],
groupByExprsOpt: Option[Seq[NamedExpression]],
pivotColumn: Expression,
pivotValues: Seq[Literal],
aggregates: Seq[Expression],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ => pivotValues.flatMap{ value =>
aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
override lazy val resolved = false // Pivot will be replaced after being resolved.
override def output: Seq[Attribute] = {
val pivotAgg = aggregates match {
case agg :: Nil =>
pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ =>
pivotValues.flatMap { value =>
aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
}
}
groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite {
"sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables",
"tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive",
"undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp",
"view", "while", "year", "work", "transaction", "write", "isolation", "level",
"snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint",
"view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot",
"autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint",
"binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp",
"cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external",
"false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in",
"insert", "int", "into", "is", "lateral", "like", "local", "none", "null",
"insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null",
"of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke",
"rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
"true", "truncate", "update", "user", "values", "with", "regexp", "rlike",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql](
case RelationalGroupedDataset.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
Dataset.ofRows(
df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
}
}

Expand Down
113 changes: 113 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/pivot.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
create temporary view courseSales as select * from values
("dotNET", 2012, 10000),
("Java", 2012, 20000),
("dotNET", 2012, 5000),
("dotNET", 2013, 48000),
("Java", 2013, 30000)
as courseSales(course, year, earnings);

create temporary view years as select * from values
(2012, 1),
(2013, 2)
as years(y, s);

-- pivot courses
SELECT * FROM (
SELECT year, course, earnings FROM courseSales
)
PIVOT (
sum(earnings)
FOR course IN ('dotNET', 'Java')
);

-- pivot years with no subquery
SELECT * FROM courseSales
PIVOT (
sum(earnings)
FOR year IN (2012, 2013)
);

-- pivot courses with multiple aggregations
SELECT * FROM (
SELECT year, course, earnings FROM courseSales
)
PIVOT (
sum(earnings), avg(earnings)
FOR course IN ('dotNET', 'Java')
);

-- pivot with no group by column
SELECT * FROM (
SELECT course, earnings FROM courseSales
)
PIVOT (
sum(earnings)
FOR course IN ('dotNET', 'Java')
);

-- pivot with no group by column and with multiple aggregations on different columns
SELECT * FROM (
SELECT year, course, earnings FROM courseSales
)
PIVOT (
sum(earnings), min(year)
FOR course IN ('dotNET', 'Java')
);

-- pivot on join query with multiple group by columns
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings)
FOR s IN (1, 2)
);

-- pivot on join query with multiple aggregations on different columns
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings), min(s)
FOR course IN ('dotNET', 'Java')
);

-- pivot on join query with multiple columns in one aggregation
SELECT * FROM (
SELECT course, year, earnings, s
FROM courseSales
JOIN years ON year = y
)
PIVOT (
sum(earnings * s)
FOR course IN ('dotNET', 'Java')
);

-- pivot with aliases and projection
SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM (
SELECT year y, course c, earnings e FROM courseSales
)
PIVOT (
sum(e) s, avg(e) a
FOR y IN (2012, 2013)
);

-- pivot years with non-aggregate function
SELECT * FROM courseSales
PIVOT (
abs(earnings)
FOR year IN (2012, 2013)
);

-- pivot with unresolvable columns
SELECT * FROM (
SELECT course, earnings FROM courseSales
)
PIVOT (
sum(earnings)
FOR year IN (2012, 2013)
);
Loading