-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17069] Expose spark.range() as table-valued function in SQL #14656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
5d7fbfb
e5b0e27
0d4ad5e
4d9eb53
21078e6
cc987d3
92e668f
d1fd2b9
fbf515f
29f9538
e9251f5
1f81b97
94ad7f1
a444b9e
7bac3af
78c9f05
4d94ac0
2f80f54
8e03f51
1fa57c4
f8831ca
7ebd563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,11 @@ | |
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.apache.spark.{SparkConf, SparkContext} | ||
| import org.apache.spark.sql.catalyst.expressions.Expression | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} | ||
| import org.apache.spark.sql.catalyst.rules._ | ||
| import org.apache.spark.sql.types.{DataType, IntegerType, LongType} | ||
|
|
||
| /** | ||
| * Rule that resolves table-valued function references. | ||
|
|
@@ -32,23 +34,26 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { | |
| /** | ||
| * List of argument names and their types, used to declare a function. | ||
| */ | ||
| private case class ArgumentList(args: (String, Class[_])*) { | ||
| private case class ArgumentList(args: (String, DataType)*) { | ||
| /** | ||
| * @return whether this list is assignable from the given sequence of values. | ||
| * Try to cast the expressions to satisfy the expected types of this argument list. If there | ||
| * are any types that cannot be casted, then None is returned. | ||
| */ | ||
| def assignableFrom(values: Seq[Any]): Boolean = { | ||
| def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { | ||
| if (args.length == values.length) { | ||
| args.zip(values).forall { case ((name, clazz), value) => | ||
| clazz.isAssignableFrom(value.getClass) | ||
| val casted = values.zip(args).map { case (value, (_, expectedType)) => | ||
| TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) | ||
| } | ||
| if (casted.forall(_.isDefined)) { | ||
| return Some(casted.map(_.get)) | ||
| } | ||
| } else { | ||
| false | ||
| } | ||
| None | ||
| } | ||
|
|
||
| override def toString: String = { | ||
| args.map { a => | ||
| s"${a._1}: ${a._2.getSimpleName}" | ||
| s"${a._1}: ${a._2.typeName}" | ||
| }.mkString(", ") | ||
| } | ||
| } | ||
|
|
@@ -65,45 +70,43 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { | |
| private val builtinFunctions: Map[String, TVF] = Map( | ||
| "range" -> Map( | ||
| /* range(end) */ | ||
| ArgumentList(("end", classOf[Number])) -> ( | ||
| ArgumentList(("end", LongType)) -> ( | ||
| (args: Seq[Any]) => | ||
| Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)), | ||
| Range(0, args(0).asInstanceOf[Long], 1, defaultParallelism)), | ||
|
|
||
| /* range(start, end) */ | ||
| ArgumentList(("start", classOf[Number]), ("end", classOf[Number])) -> ( | ||
| ArgumentList(("start", LongType), ("end", LongType)) -> ( | ||
| (args: Seq[Any]) => | ||
| Range( | ||
| args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1, | ||
| defaultParallelism)), | ||
| args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], 1, defaultParallelism)), | ||
|
|
||
| /* range(start, end, step) */ | ||
| ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), | ||
| ("step", classOf[Number])) -> ( | ||
| ArgumentList(("start", LongType), ("end", LongType), ("step", LongType)) -> ( | ||
| (args: Seq[Any]) => | ||
| Range( | ||
| args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, | ||
| args(2).asInstanceOf[Number].longValue, defaultParallelism)), | ||
| args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], | ||
| defaultParallelism)), | ||
|
|
||
| /* range(start, end, step, numPartitions) */ | ||
| ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), | ||
| ("step", classOf[Number]), ("numPartitions", classOf[Integer])) -> ( | ||
| ArgumentList(("start", LongType), ("end", LongType), ("step", LongType), | ||
| ("numPartitions", IntegerType)) -> ( | ||
| (args: Seq[Any]) => | ||
| Range( | ||
| args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, | ||
| args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer]))) | ||
| args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], | ||
| args(3).asInstanceOf[Integer]))) | ||
| ) | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case u: UnresolvedTableValuedFunction => | ||
| builtinFunctions.get(u.functionName) match { | ||
| case Some(tvf) => | ||
| val evaluatedArgs = u.functionArgs.map(_.eval()) | ||
| for ((argList, resolver) <- tvf) { | ||
| if (argList.assignableFrom(evaluatedArgs)) { | ||
| return resolver(evaluatedArgs) | ||
| val casted = argList.implicitCast(u.functionArgs) | ||
| if (casted.isDefined) { | ||
| return resolver(casted.get.map(_.eval())) | ||
|
||
| } | ||
| } | ||
| val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ") | ||
| val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") | ||
| u.failAnalysis( | ||
| s"""error: table-valued function ${u.functionName} with alternatives: | ||
| |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,4 +14,4 @@ select * from range(0, 10, 2); | |
| select * from range(0, 10, 1, 200); | ||
|
|
||
| -- range call error | ||
| select * from range(2, 'x'); | ||
| select * from range(1, 1, 1, 1, 1); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you also test nulls?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could make this a bit more concise by using a combination of a builder and partial function. For example:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems nice. Updated.