From 0904fc9feabf348c92ebeaff0fd2a506c1cec128 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 18 Apr 2017 09:39:26 +0900 Subject: [PATCH 1/4] Support aliases for table value functions --- .../spark/sql/catalyst/parser/SqlBase.g4 | 20 +++++++--- .../ResolveTableValuedFunctions.scala | 22 ++++++++-- .../sql/catalyst/analysis/unresolved.scala | 10 ++++- .../sql/catalyst/parser/AstBuilder.scala | 17 ++++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 ++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 13 +++++- .../inputs/table-valued-functions.sql | 4 ++ .../results/table-valued-functions.sql.out | 40 ++++++++++++++++++- 8 files changed, 122 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 14c511f67060..ed5450b494cc 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,15 +472,23 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : identifier '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? ; rowFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f2..dad1340571cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { + val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } + + // If alias names assigned, add `Project` with the aliases + if (u.outputNames.nonEmpty) { + val outputAttrs = resolvedFunc.output + // Checks if the number of the aliases is equal to expected one + if (u.outputNames.size != outputAttrs.size) { + u.failAnalysis(s"expected ${outputAttrs.size} columns but " + + s"found ${u.outputNames.size} columns") + } + val aliases = outputAttrs.zip(u.outputNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, resolvedFunc) + } else { + resolvedFunc + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 262b894e2a0a..51bef6e20b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,10 +66,16 @@ case class UnresolvedInlineTable( /** * A table-valued function, e.g. * {{{ - * select * from range(10); + * select id from range(10); + * + * // Assign alias names + * select t.a from range(10) t(a); * }}} */ -case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) +case class UnresolvedTableValuedFunction( + functionName: String, + functionArgs: Seq[Expression], + outputNames: Seq[String]) extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2a9b4a9a9f5..046ea65d454a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) } /** @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - val aliases = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } val table = UnresolvedInlineTable(aliases, rows) - table.optionalMap(ctx.identifier)(aliasPlan) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea..31047f688600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("SPARK-20311 range(N) as alias") { + def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) + .select(star()) + } + assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) + assertAnalysisError( + rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), + Seq("expected 1 columns but found 2 columns")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a..4c2476296c04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest { test("table valued function") { assertEqual( "select * from range(2)", - UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star())) + } + + test("SPARK-20311 range(N) as alias") { + assertEqual( + "select * from range(10) AS t", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) + .select(star())) + assertEqual( + "select * from range(7) AS t(a)", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) + .select(star())) } test("inline table") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index d0d2df7b243d..47ca9f1de410 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -24,3 +24,7 @@ select * from RaNgE(2); -- Explain EXPLAIN select * from RaNgE(2); + +-- cross-join table valued functions +set spark.sql.crossJoin.enabled=true; +EXPLAIN EXTENDED select * from range(3) cross join range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index e2ee970d35f6..8e4fb82bb06f 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 11 -- !query 0 @@ -103,3 +103,41 @@ struct -- !query 8 output == Physical Plan == *Range (0, 2, step=1, splits=2) + + +-- !query 9 +set spark.sql.crossJoin.enabled=true +-- !query 9 schema +struct +-- !query 9 output +spark.sql.crossJoin.enabled true + + +-- !query 10 +EXPLAIN EXTENDED select * from range(3) cross join range(3) +-- !query 10 schema +struct +-- !query 10 output +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedTableValuedFunction range, [3] + +- 'UnresolvedTableValuedFunction range, [3] + +== Analyzed Logical Plan == +id: bigint, id: bigint +Project [id#xL, id#xL] ++- Join Cross + :- Range (0, 3, step=1, splits=None) + +- Range (0, 3, step=1, splits=None) + +== Optimized Logical Plan == +Join Cross +:- Range (0, 3, step=1, splits=None) ++- Range (0, 3, step=1, splits=None) + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- *Range (0, 3, step=1, splits=2) ++- BroadcastExchange IdentityBroadcastMode + +- *Range (0, 3, step=1, splits=2) From 29281b1af00ced947173701378e9ad1b67e4924c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 10 May 2017 16:37:24 +0900 Subject: [PATCH 2/4] Apply comments --- .../spark/sql/catalyst/parser/PlanParserSuite.scala | 4 ++-- .../resources/sql-tests/inputs/inline-table.sql | 3 +++ .../sql-tests/inputs/table-valued-functions.sql | 4 ++-- .../sql-tests/results/inline-table.sql.out | 13 ++++++++++++- .../results/table-valued-functions.sql.out | 4 ++-- 5 files changed, 21 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 4c2476296c04..cf137cfdf96e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -473,11 +473,11 @@ class PlanParserSuite extends PlanTest { test("SPARK-20311 range(N) as alias") { assertEqual( - "select * from range(10) AS t", + "SELECT * FROM range(10) AS t", SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) .select(star())) assertEqual( - "select * from range(7) AS t(a)", + "SELECT * FROM range(7) AS t(a)", SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) .select(star())) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index b3ec956cd178..815820b0c944 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -49,3 +49,6 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); -- string to timestamp select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); + +-- cross-join inline tables +SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 47ca9f1de410..f0390b03be65 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -26,5 +26,5 @@ select * from RaNgE(2); EXPLAIN select * from RaNgE(2); -- cross-join table valued functions -set spark.sql.crossJoin.enabled=true; -EXPLAIN EXTENDED select * from range(3) cross join range(3); +SET spark.sql.crossJoin.enabled=true; +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 4e80f0bda551..d8c6b4c8bbfe 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 18 -- !query 0 @@ -151,3 +151,14 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- struct> -- !query 16 output 1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] + + +-- !query 17 +SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) +-- !query 17 schema +struct +-- !query 17 output +one 1 one 1 +one 1 three NULL +three NULL one 1 +three NULL three NULL diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 8e4fb82bb06f..0461ba6886dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -106,7 +106,7 @@ struct -- !query 9 -set spark.sql.crossJoin.enabled=true +SET spark.sql.crossJoin.enabled=true -- !query 9 schema struct -- !query 9 output @@ -114,7 +114,7 @@ spark.sql.crossJoin.enabled true -- !query 10 -EXPLAIN EXTENDED select * from range(3) cross join range(3) +EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) -- !query 10 schema struct -- !query 10 output From 54a05daaf26c9ba7228e3d9c35085c3bc6015d65 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 12:39:14 +0900 Subject: [PATCH 3/4] Remove unnecessary line --- .../sql-tests/inputs/table-valued-functions.sql | 1 - .../results/table-valued-functions.sql.out | 14 +++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index f0390b03be65..72cd8ca9d872 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -26,5 +26,4 @@ select * from RaNgE(2); EXPLAIN select * from RaNgE(2); -- cross-join table valued functions -SET spark.sql.crossJoin.enabled=true; EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 0461ba6886dc..a8bc6faf1126 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 10 -- !query 0 @@ -106,18 +106,10 @@ struct -- !query 9 -SET spark.sql.crossJoin.enabled=true --- !query 9 schema -struct --- !query 9 output -spark.sql.crossJoin.enabled true - - --- !query 10 EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) --- !query 10 schema +-- !query 9 schema struct --- !query 10 output +-- !query 9 output == Parsed Logical Plan == 'Project [*] +- 'Join Cross From a7a732bdbef47b4b4695843ab8005a4a85b037da Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 15:49:15 +0900 Subject: [PATCH 4/4] Apply comments --- .../sql-tests/inputs/inline-table.sql | 2 +- .../sql-tests/results/inline-table.sql.out | 31 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 815820b0c944..41d316444ed6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -51,4 +51,4 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); -- cross-join inline tables -SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index d8c6b4c8bbfe..c065ce501292 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -154,11 +154,30 @@ struct> -- !query 17 -SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) +EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) -- !query 17 schema -struct +struct -- !query 17 output -one 1 one 1 -one 1 three NULL -three NULL one 1 -three NULL three NULL +== Parsed Logical Plan == +'Project [*] ++- 'Join Cross + :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] + +== Analyzed Logical Plan == +col1: string, col2: int, col1: string, col2: int +Project [col1#x, col2#x, col1#x, col2#x] ++- Join Cross + :- LocalRelation [col1#x, col2#x] + +- LocalRelation [col1#x, col2#x] + +== Optimized Logical Plan == +Join Cross +:- LocalRelation [col1#x, col2#x] ++- LocalRelation [col1#x, col2#x] + +== Physical Plan == +BroadcastNestedLoopJoin BuildRight, Cross +:- LocalTableScan [col1#x, col2#x] ++- BroadcastExchange IdentityBroadcastMode + +- LocalTableScan [col1#x, col2#x]