Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Mon Aug 15 00:43:07 PDT 2016
  • Loading branch information
ericl committed Aug 15, 2016
commit e5b0e2747e7e0bf254f173fbc0752e2e47266918
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ relationPrimary
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| identifier '(' expression (',' expression)* ')' #tableValuedFunction
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
;

inlineTable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,99 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._

/**
* Rule for resolving references to table-valued functions. Currently this only resolves
* references to the hard-coded range() operator.
* Rule that resolves table-valued function references.
*/
object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
private def defaultParallelism: Int = 200 // TODO(ekl) fix
private lazy val defaultParallelism =
SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction =>
// TODO(ekl) we should have a tvf registry
if (u.functionName != "range") {
u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function")
/**
* Type aliases for a TVF declaration. A TVF maps a sequence of named arguments to a function
* resolving the TVF given a matching list of arguments from the user. This allows for
* function overloading (e.g. range(100), range(0, 100)).
*/
private type NamedArguments = Seq[Tuple2[String, Class[_]]]
private type TVF = Map[NamedArguments, Seq[Any] => LogicalPlan]

/**
* Internal registry of table-valued-functions. TODO(ekl) we should have a proper registry
*/
private val builtinFunctions: Map[String, TVF] = Map(
"range" -> Map(
/* range(end) */
Seq(("end", classOf[Number])) -> (
(args: Seq[Any]) =>
Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)),

/* range(start, end) */
Seq(("start", classOf[Number]), ("end", classOf[Number])) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1,
defaultParallelism)),

/* range(start, end, step) */
Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number])) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue,
args(2).asInstanceOf[Number].longValue, defaultParallelism)),

/* range(start, end, step, numPartitions) */
Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number]),
("numPartitions", classOf[Integer])) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue,
args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer]))
)
)

/**
* Returns whether a given sequence of values can be assigned to the specified arguments.
*/
private def assignableFrom(args: NamedArguments, values: Seq[Any]): Boolean = {
if (args.length == values.length) {
args.zip(values).forall { case ((name, clazz), value) =>
clazz.isAssignableFrom(value.getClass)
}
val evaluatedArgs = u.functionArgs.map(_.eval())
val longArgs = evaluatedArgs.map(_.toString.toLong) // TODO(ekl) fix
longArgs match {
case Seq(end) =>
Range(0, end, 1, defaultParallelism)
case Seq(start, end) =>
Range(start, end, 1, defaultParallelism)
case Seq(start, end, step) =>
Range(start, end, step, defaultParallelism)
case Seq(start, end, step, numPartitions) =>
Range(start, end, step, numPartitions.toInt)
} else {
false
}
}

/**
* Formats a list of named args, e.g. to "start: Number, end: Number, steps: Number".
*/
private def formatArgs(args: NamedArguments): String = {
args.map { a =>
s"${a._1}: ${a._2.getSimpleName}"
}.mkString(", ")
}

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction =>
builtinFunctions.get(u.functionName) match {
case Some(tvf) =>
val evaluatedArgs = u.functionArgs.map(_.eval())
for ((argSpec, resolver) <- tvf) {
if (assignableFrom(argSpec, evaluatedArgs)) {
return resolver(evaluatedArgs)
}
}
val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ")
u.failAnalysis(
s"""error: table-valued function ${u.functionName} with alternatives:
|${tvf.keys.map(formatArgs).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
|cannot be applied to: (${argTypes})""".stripMargin)
case _ =>
u.failAnalysis(s"invalid number of argument for range(): ${longArgs}")
u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ case class UnresolvedRelation(
}

/**
* Holds a table-valued-function call that has yet to be resolved.
* Holds a table-valued function call that has yet to be resolved.
*/
case class UnresolvedTableValuedFunction(
functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,13 +653,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}

/**
* Create a table-valued-function call with arguments, e.g. range(1000)
* Create a table-valued function call with arguments, e.g. range(1000)
*/
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit style I think we need to put the every arguments on a separate line if the unbroken line exceeds 100 characters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a fine style actually.

Copy link
Contributor

@hvanhovell hvanhovell Aug 17, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind then :)

: LogicalPlan = withOrigin(ctx) {
val expressions = ctx.expression.asScala.map { ec =>
val e = expression(ec)
assert(e.foldable, "All params of a table-valued-function call must be constants.", ec)
assert(e.foldable, "All arguments of a table-valued-function must be constants.", ec)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not part of this pr -- @hvanhovell can we rename this assert? It is actually not the normal assert that only fails when there is a bug in Spark. This is used to catch user code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to this pr, should this be checked in analysis time rather than expression time? Otherwise you wouldn't be able to do 1 + 2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parser will actually support 1 + 2 because it has built in support for symbolic functions. How about we defer this until CheckAnalysis? It is sometimes nice to use more complex expressions.

Copy link
Contributor Author

@ericl ericl Aug 16, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean defer evaluation of the arguments until after analysis? Otherwise, it would presumably crash during resolution when eval is called if this check is moved.

Copy link
Contributor

@hvanhovell hvanhovell Aug 16, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, lets defer evaluation until you need them in ResolveTableValuedFunctions. You will need to check if they are resolved though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #14665 to rename the assert.

e
}
UnresolvedTableValuedFunction(ctx.identifier.getText, expressions)
Expand Down