From 02e70076d42c951e9a605c31be9cdc364b652d72 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 18 Apr 2017 09:39:26 +0900 Subject: [PATCH 1/4] Support aliases for table value functions --- .../spark/sql/catalyst/parser/SqlBase.g4 | 10 ++-- .../ResolveTableValuedFunctions.scala | 51 ++++++++++++------- .../sql/catalyst/analysis/unresolved.scala | 5 +- .../sql/catalyst/parser/AstBuilder.scala | 16 +++++- .../plans/logical/basicLogicalOperators.scala | 10 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 12 ++++- .../sql/catalyst/parser/PlanParserSuite.scala | 13 ++++- 7 files changed, 88 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 14c511f67060..8c171cd3e6fa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,11 +472,11 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? ')' (AS? identifier identifierList?)? #tableValuedFunction ; inlineTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f2..33fa7b614fc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -59,19 +59,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { * A TVF maps argument lists to resolver functions that accept those arguments. Using a map * here allows for function overloading. */ - private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] + private type TVF = Map[ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan] /** * TVF builder. */ - private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) - : (ArgumentList, Seq[Any] => LogicalPlan) = { + private def tvf(args: (String, DataType)*)( + pf: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan]) + : (ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan) = { + val failAnalysis: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan] = { + case (pf: UnresolvedTableValuedFunction, args: Seq[Any]) => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + } (ArgumentList(args: _*), - pf orElse { - case args => - throw new IllegalArgumentException( - "Invalid arguments for resolved function: " + args.mkString(", ")) - }) + (tvf: UnresolvedTableValuedFunction, args: Seq[Any]) => pf.orElse(failAnalysis)(tvf, args)) } /** @@ -80,29 +82,44 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( /* range(end) */ - tvf("end" -> LongType) { case Seq(end: Long) => - Range(0, end, 1, None) + tvf("end" -> LongType) { case (tvf, args @ Seq(end: Long)) => + validateInputDimension(tvf, 1) + Range(0, end, 1, None, tvf.outputNames.headOption) }, /* range(start, end) */ - tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => - Range(start, end, 1, None) + tvf("start" -> LongType, "end" -> LongType) { + case (tvf, args @ Seq(start: Long, end: Long)) => + validateInputDimension(tvf, 1) + Range(start, end, 1, None, tvf.outputNames.headOption) }, /* range(start, end, step) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { - case Seq(start: Long, end: Long, step: Long) => - Range(start, end, step, None) + case (tvf, args @ Seq(start: Long, end: Long, step: Long)) => + validateInputDimension(tvf, 1) + Range(start, end, step, None, tvf.outputNames.headOption) }, /* range(start, end, step, numPartitions) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, "numPartitions" -> IntegerType) { - case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => - Range(start, end, step, Some(numPartitions)) + case (tvf, args @ Seq(start: Long, end: Long, step: Long, numPartitions: Int)) => + validateInputDimension(tvf, 1) + Range(start, end, step, Some(numPartitions), tvf.outputNames.headOption) }) ) + private def validateInputDimension(tvf: UnresolvedTableValuedFunction, expectedNumCols: Int) + : Unit = { + if (tvf.outputNames.nonEmpty) { + val numCols = tvf.outputNames.size + if (numCols != expectedNumCols) { + tvf.failAnalysis(s"expected $expectedNumCols columns but found $numCols columns") + } + } + } + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { @@ -110,7 +127,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { case Some(casted) => - Some(resolver(casted.map(_.eval()))) + Some(resolver(u, casted.map(_.eval()))) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 262b894e2a0a..397ee6a44e4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -69,7 +69,10 @@ case class UnresolvedInlineTable( * select * from range(10); * }}} */ -case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) +case class UnresolvedTableValuedFunction( + functionName: String, + functionArgs: Seq[Expression], + outputNames: Seq[String]) extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2a9b4a9a9f5..a0565dac3bbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,7 +687,21 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + val (tvfName, aliasNameOption) = ctx.identifier.asScala.map(_.getText) match { + case Seq(funcName, aliasName) => (funcName, Some(aliasName)) + case Seq(funcName) => (funcName, None) + } + val outputNames = if (ctx.identifierList != null) { + visitIdentifierList(ctx.identifierList) + } else { + Seq.empty + } + val plan = UnresolvedTableValuedFunction( + tvfName, ctx.expression.asScala.map(expression), outputNames) + aliasNameOption match { + case Some(aliasName) => SubqueryAlias(aliasName, plan) + case _ => plan + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f663d7b8a8f7..285ecb26d258 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -498,12 +498,16 @@ case class Sort( /** Factory for constructing new `Range` nodes. */ object Range { - def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { - val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes + + def apply(start: Long, end: Long, step: Long, numSlices: Option[Int], outputName: Option[String]) + : Range = { + val name = outputName.getOrElse("id") + val output = StructType(StructField(name, LongType, nullable = false) :: Nil).toAttributes new Range(start, end, step, numSlices, output) } + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { - Range(start, end, step, Some(numSlices)) + Range(start, end, step, Some(numSlices), None) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea..fc5c16cfe46c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -441,4 +440,15 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("SPARK-20311 range(N) as alias") { + def rangeWithAliases(outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, outputNames)) + .select(star()) + } + assertAnalysisSuccess(rangeWithAliases("a" :: Nil)) + assertAnalysisError( + rangeWithAliases("a" :: "b" :: Nil), + Seq("expected 1 columns but found 2 columns")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a..4c2476296c04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest { test("table valued function") { assertEqual( "select * from range(2)", - UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star())) + } + + test("SPARK-20311 range(N) as alias") { + assertEqual( + "select * from range(10) AS t", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) + .select(star())) + assertEqual( + "select * from range(7) AS t(a)", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) + .select(star())) } test("inline table") { From 50fc51929a76f5199091bd1d61d1eb2589031fff Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 00:37:43 +0900 Subject: [PATCH 2/4] Add a rule for table aliases --- .../spark/sql/catalyst/parser/SqlBase.g4 | 16 ++++++++----- .../sql/catalyst/parser/AstBuilder.scala | 24 +++++++------------ 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 8c171cd3e6fa..9529216543e1 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,15 +472,19 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' (AS? identifier identifierList?)? #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? ')' tableAlias #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? + : VALUES expression (',' expression)* tableAlias + ; + +tableAlias + : (AS? identifier identifierList?)? ; rowFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a0565dac3bbb..3e24f7d405ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,21 +687,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - val (tvfName, aliasNameOption) = ctx.identifier.asScala.map(_.getText) match { - case Seq(funcName, aliasName) => (funcName, Some(aliasName)) - case Seq(funcName) => (funcName, None) - } - val outputNames = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.empty } - val plan = UnresolvedTableValuedFunction( - tvfName, ctx.expression.asScala.map(expression), outputNames) - aliasNameOption match { - case Some(aliasName) => SubqueryAlias(aliasName, plan) - case _ => plan - } + + val tvf = UnresolvedTableValuedFunction( + ctx.identifier.getText, ctx.expression.asScala.map(expression), aliases) + tvf.optionalMap(ctx.tableAlias.identifier)(aliasPlan) } /** @@ -719,14 +713,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - val aliases = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } val table = UnresolvedInlineTable(aliases, rows) - table.optionalMap(ctx.identifier)(aliasPlan) + table.optionalMap(ctx.tableAlias.identifier)(aliasPlan) } /** From 399d823b13719c7623013164f284cb6668c88218 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 May 2017 09:44:04 +0900 Subject: [PATCH 3/4] Apply comments --- .../spark/sql/catalyst/parser/SqlBase.g4 | 14 ++-- .../ResolveTableValuedFunctions.scala | 73 +++++++++---------- .../sql/catalyst/analysis/unresolved.scala | 5 +- .../sql/catalyst/parser/AstBuilder.scala | 9 ++- .../plans/logical/basicLogicalOperators.scala | 10 +-- 5 files changed, 57 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9529216543e1..41daf58a98fd 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,17 +472,21 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' tableAlias #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable : VALUES expression (',' expression)* tableAlias ; +functionTable + : identifier '(' (expression (',' expression)*)? ')' tableAlias + ; + tableAlias : (AS? identifier identifierList?)? ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 33fa7b614fc5..dad1340571cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -59,21 +59,19 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { * A TVF maps argument lists to resolver functions that accept those arguments. Using a map * here allows for function overloading. */ - private type TVF = Map[ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan] + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] /** * TVF builder. */ - private def tvf(args: (String, DataType)*)( - pf: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan]) - : (ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan) = { - val failAnalysis: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan] = { - case (pf: UnresolvedTableValuedFunction, args: Seq[Any]) => - throw new IllegalArgumentException( - "Invalid arguments for resolved function: " + args.mkString(", ")) - } + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = { (ArgumentList(args: _*), - (tvf: UnresolvedTableValuedFunction, args: Seq[Any]) => pf.orElse(failAnalysis)(tvf, args)) + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) } /** @@ -82,52 +80,37 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( /* range(end) */ - tvf("end" -> LongType) { case (tvf, args @ Seq(end: Long)) => - validateInputDimension(tvf, 1) - Range(0, end, 1, None, tvf.outputNames.headOption) + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, None) }, /* range(start, end) */ - tvf("start" -> LongType, "end" -> LongType) { - case (tvf, args @ Seq(start: Long, end: Long)) => - validateInputDimension(tvf, 1) - Range(start, end, 1, None, tvf.outputNames.headOption) + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, None) }, /* range(start, end, step) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { - case (tvf, args @ Seq(start: Long, end: Long, step: Long)) => - validateInputDimension(tvf, 1) - Range(start, end, step, None, tvf.outputNames.headOption) + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, None) }, /* range(start, end, step, numPartitions) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, "numPartitions" -> IntegerType) { - case (tvf, args @ Seq(start: Long, end: Long, step: Long, numPartitions: Int)) => - validateInputDimension(tvf, 1) - Range(start, end, step, Some(numPartitions), tvf.outputNames.headOption) + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, Some(numPartitions)) }) ) - private def validateInputDimension(tvf: UnresolvedTableValuedFunction, expectedNumCols: Int) - : Unit = { - if (tvf.outputNames.nonEmpty) { - val numCols = tvf.outputNames.size - if (numCols != expectedNumCols) { - tvf.failAnalysis(s"expected $expectedNumCols columns but found $numCols columns") - } - } - } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { + val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { case Some(casted) => - Some(resolver(u, casted.map(_.eval()))) + Some(resolver(casted.map(_.eval()))) case _ => None } @@ -142,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } + + // If alias names assigned, add `Project` with the aliases + if (u.outputNames.nonEmpty) { + val outputAttrs = resolvedFunc.output + // Checks if the number of the aliases is equal to expected one + if (u.outputNames.size != outputAttrs.size) { + u.failAnalysis(s"expected ${outputAttrs.size} columns but " + + s"found ${u.outputNames.size} columns") + } + val aliases = outputAttrs.zip(u.outputNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, resolvedFunc) + } else { + resolvedFunc + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 397ee6a44e4e..51bef6e20b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,7 +66,10 @@ case class UnresolvedInlineTable( /** * A table-valued function, e.g. * {{{ - * select * from range(10); + * select id from range(10); + * + * // Assign alias names + * select t.a from range(10) t(a); * }}} */ case class UnresolvedTableValuedFunction( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3e24f7d405ee..e03fe2ccb8d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,15 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - val aliases = if (ctx.tableAlias.identifierList != null) { - visitIdentifierList(ctx.tableAlias.identifierList) + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) } else { Seq.empty } val tvf = UnresolvedTableValuedFunction( - ctx.identifier.getText, ctx.expression.asScala.map(expression), aliases) - tvf.optionalMap(ctx.tableAlias.identifier)(aliasPlan) + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.identifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 285ecb26d258..f663d7b8a8f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -498,16 +498,12 @@ case class Sort( /** Factory for constructing new `Range` nodes. */ object Range { - - def apply(start: Long, end: Long, step: Long, numSlices: Option[Int], outputName: Option[String]) - : Range = { - val name = outputName.getOrElse("id") - val output = StructType(StructField(name, LongType, nullable = false) :: Nil).toAttributes + def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { + val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes new Range(start, end, step, numSlices, output) } - def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { - Range(start, end, step, Some(numSlices), None) + Range(start, end, step, Some(numSlices)) } } From 81bef3ba21cb0c3e4b36f3fc492d9ab3a3124829 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 May 2017 11:05:49 +0900 Subject: [PATCH 4/4] Add more test cases --- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fc5c16cfe46c..31047f688600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -442,13 +442,15 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("SPARK-20311 range(N) as alias") { - def rangeWithAliases(outputNames: Seq[String]): LogicalPlan = { - SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, outputNames)) + def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) .select(star()) } - assertAnalysisSuccess(rangeWithAliases("a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) assertAnalysisError( - rangeWithAliases("a" :: "b" :: Nil), + rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), Seq("expected 1 columns but found 2 columns")) } }