-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-20311][SQL] Support aliases for table value functions #17666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,19 +59,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { | |
| * A TVF maps argument lists to resolver functions that accept those arguments. Using a map | ||
| * here allows for function overloading. | ||
| */ | ||
| private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] | ||
| private type TVF = Map[ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan] | ||
|
|
||
| /** | ||
| * TVF builder. | ||
| */ | ||
| private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) | ||
| : (ArgumentList, Seq[Any] => LogicalPlan) = { | ||
| private def tvf(args: (String, DataType)*)( | ||
| pf: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan]) | ||
|
||
| : (ArgumentList, (UnresolvedTableValuedFunction, Seq[Any]) => LogicalPlan) = { | ||
| val failAnalysis: PartialFunction[(UnresolvedTableValuedFunction, Seq[Any]), LogicalPlan] = { | ||
| case (pf: UnresolvedTableValuedFunction, args: Seq[Any]) => | ||
| throw new IllegalArgumentException( | ||
| "Invalid arguments for resolved function: " + args.mkString(", ")) | ||
| } | ||
| (ArgumentList(args: _*), | ||
| pf orElse { | ||
| case args => | ||
| throw new IllegalArgumentException( | ||
| "Invalid arguments for resolved function: " + args.mkString(", ")) | ||
| }) | ||
| (tvf: UnresolvedTableValuedFunction, args: Seq[Any]) => pf.orElse(failAnalysis)(tvf, args)) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -80,37 +82,52 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { | |
| private val builtinFunctions: Map[String, TVF] = Map( | ||
| "range" -> Map( | ||
| /* range(end) */ | ||
| tvf("end" -> LongType) { case Seq(end: Long) => | ||
| Range(0, end, 1, None) | ||
| tvf("end" -> LongType) { case (tvf, args @ Seq(end: Long)) => | ||
| validateInputDimension(tvf, 1) | ||
| Range(0, end, 1, None, tvf.outputNames.headOption) | ||
|
||
| }, | ||
|
|
||
| /* range(start, end) */ | ||
| tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => | ||
| Range(start, end, 1, None) | ||
| tvf("start" -> LongType, "end" -> LongType) { | ||
| case (tvf, args @ Seq(start: Long, end: Long)) => | ||
| validateInputDimension(tvf, 1) | ||
| Range(start, end, 1, None, tvf.outputNames.headOption) | ||
| }, | ||
|
|
||
| /* range(start, end, step) */ | ||
| tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { | ||
| case Seq(start: Long, end: Long, step: Long) => | ||
| Range(start, end, step, None) | ||
| case (tvf, args @ Seq(start: Long, end: Long, step: Long)) => | ||
| validateInputDimension(tvf, 1) | ||
| Range(start, end, step, None, tvf.outputNames.headOption) | ||
| }, | ||
|
|
||
| /* range(start, end, step, numPartitions) */ | ||
| tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, | ||
| "numPartitions" -> IntegerType) { | ||
| case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => | ||
| Range(start, end, step, Some(numPartitions)) | ||
| case (tvf, args @ Seq(start: Long, end: Long, step: Long, numPartitions: Int)) => | ||
| validateInputDimension(tvf, 1) | ||
| Range(start, end, step, Some(numPartitions), tvf.outputNames.headOption) | ||
| }) | ||
| ) | ||
|
|
||
| private def validateInputDimension(tvf: UnresolvedTableValuedFunction, expectedNumCols: Int) | ||
|
||
| : Unit = { | ||
| if (tvf.outputNames.nonEmpty) { | ||
| val numCols = tvf.outputNames.size | ||
| if (numCols != expectedNumCols) { | ||
| tvf.failAnalysis(s"expected $expectedNumCols columns but found $numCols columns") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => | ||
| builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { | ||
| case Some(tvf) => | ||
| val resolved = tvf.flatMap { case (argList, resolver) => | ||
| argList.implicitCast(u.functionArgs) match { | ||
| case Some(casted) => | ||
| Some(resolver(casted.map(_.eval()))) | ||
| Some(resolver(u, casted.map(_.eval()))) | ||
| case _ => | ||
| None | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,7 +69,10 @@ case class UnresolvedInlineTable( | |
| * select * from range(10); | ||
|
||
| * }}} | ||
| */ | ||
| case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) | ||
| case class UnresolvedTableValuedFunction( | ||
| functionName: String, | ||
| functionArgs: Seq[Expression], | ||
| outputNames: Seq[String]) | ||
| extends LeafNode { | ||
|
|
||
| override def output: Seq[Attribute] = Nil | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -498,12 +498,16 @@ case class Sort( | |
|
|
||
| /** Factory for constructing new `Range` nodes. */ | ||
| object Range { | ||
| def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { | ||
| val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes | ||
|
|
||
| def apply(start: Long, end: Long, step: Long, numSlices: Option[Int], outputName: Option[String]) | ||
|
||
| : Range = { | ||
| val name = outputName.getOrElse("id") | ||
| val output = StructType(StructField(name, LongType, nullable = false) :: Nil).toAttributes | ||
| new Range(start, end, step, numSlices, output) | ||
| } | ||
|
|
||
| def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { | ||
| Range(start, end, step, Some(numSlices)) | ||
| Range(start, end, step, Some(numSlices), None) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier | |
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.parser.CatalystSqlParser | ||
| import org.apache.spark.sql.catalyst.plans.Cross | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.types._ | ||
|
|
@@ -441,4 +440,15 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { | |
|
|
||
| checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) | ||
| } | ||
|
|
||
| test("SPARK-20311 range(N) as alias") { | ||
| def rangeWithAliases(outputNames: Seq[String]): LogicalPlan = { | ||
| SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, outputNames)) | ||
|
||
| .select(star()) | ||
| } | ||
| assertAnalysisSuccess(rangeWithAliases("a" :: Nil)) | ||
| assertAnalysisError( | ||
| rangeWithAliases("a" :: "b" :: Nil), | ||
| Seq("expected 1 columns but found 2 columns")) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we should put the multi-alias handling in a separate rule, since it is also used by inline tables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you add the multi-alias anyway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed the point. Okay, I'll reconsider this. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense to add this. Let's keep the multi-alias for now.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, it seems I misunderstood what you pointed out. You meant: do we need to support a query like
`SELECT * FROM [[tvf]] AS t(a, b, ...)` in this PR? Yes, I know we currently support only `range` as a table-valued function, but I also think it'd be better to put a more general rule in this file. So, +1 for keeping this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then, I'll update this PR to separate this rule and share it with the inline table rule.