Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Mon Aug 15 23:06:33 PDT 2016
  • Loading branch information
ericl committed Aug 16, 2016
commit 2f80f549dd3d765fd3fdc63f9795d8e5562e38fa
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

/**
* Rule that resolves table-valued function references.
Expand All @@ -32,23 +34,26 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
/**
* List of argument names and their types, used to declare a function.
*/
private case class ArgumentList(args: (String, Class[_])*) {
private case class ArgumentList(args: (String, DataType)*) {
/**
* @return whether this list is assignable from the given sequence of values.
* Try to cast the expressions to satisfy the expected types of this argument list. If there
are any types that cannot be cast, then None is returned.
*/
def assignableFrom(values: Seq[Any]): Boolean = {
def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
if (args.length == values.length) {
args.zip(values).forall { case ((name, clazz), value) =>
clazz.isAssignableFrom(value.getClass)
val casted = values.zip(args).map { case (value, (_, expectedType)) =>
TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
}
if (casted.forall(_.isDefined)) {
return Some(casted.map(_.get))
}
} else {
false
}
None
}

override def toString: String = {
args.map { a =>
s"${a._1}: ${a._2.getSimpleName}"
s"${a._1}: ${a._2.typeName}"
}.mkString(", ")
}
}
Expand All @@ -65,45 +70,43 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
private val builtinFunctions: Map[String, TVF] = Map(
"range" -> Map(
/* range(end) */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make this a bit more concise by using a combination of a builder and partial function. For example:

// Builder
def tvf(params: (String, DataType)*)(pf: PartialFunction[Seq[Expression], LogicalPlan]): TVF = (ArgumentList(params: _*), pf)

// Use
private val builtinFunctions: Map[String, TVF] = Map(
  "range" -> Map(
    /* range(end) */
    tvf("end" -> LongType) { case Seq(end: Long) =>
      Range(0, end, 1, defaultParallelism)
    },
    /* range(start, end) */
    tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
      Range(start, end, 1, defaultParallelism)
    }
    /* ... */)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems nice. Updated.

ArgumentList(("end", classOf[Number])) -> (
ArgumentList(("end", LongType)) -> (
(args: Seq[Any]) =>
Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)),
Range(0, args(0).asInstanceOf[Long], 1, defaultParallelism)),

/* range(start, end) */
ArgumentList(("start", classOf[Number]), ("end", classOf[Number])) -> (
ArgumentList(("start", LongType), ("end", LongType)) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1,
defaultParallelism)),
args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], 1, defaultParallelism)),

/* range(start, end, step) */
ArgumentList(("start", classOf[Number]), ("end", classOf[Number]),
("step", classOf[Number])) -> (
ArgumentList(("start", LongType), ("end", LongType), ("step", LongType)) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue,
args(2).asInstanceOf[Number].longValue, defaultParallelism)),
args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long],
defaultParallelism)),

/* range(start, end, step, numPartitions) */
ArgumentList(("start", classOf[Number]), ("end", classOf[Number]),
("step", classOf[Number]), ("numPartitions", classOf[Integer])) -> (
ArgumentList(("start", LongType), ("end", LongType), ("step", LongType),
("numPartitions", IntegerType)) -> (
(args: Seq[Any]) =>
Range(
args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue,
args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer])))
args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long],
args(3).asInstanceOf[Integer])))
)

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction =>
builtinFunctions.get(u.functionName) match {
case Some(tvf) =>
val evaluatedArgs = u.functionArgs.map(_.eval())
for ((argList, resolver) <- tvf) {
if (argList.assignableFrom(evaluatedArgs)) {
return resolver(evaluatedArgs)
val casted = argList.implicitCast(u.functionArgs)
if (casted.isDefined) {
return resolver(casted.get.map(_.eval()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smallish: Do you think we can avoid a non-local return?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flatMap it is

}
}
val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ")
val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
u.failAnalysis(
s"""error: table-valued function ${u.functionName} with alternatives:
|${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ select * from range(0, 10, 2);
select * from range(0, 10, 1, 200);

-- range call error
select * from range(2, 'x');
select * from range(1, 1, 1, 1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also test nulls?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ struct<id:bigint>


-- !query 5
select * from range(2, 'x')
select * from range(1, 1, 1, 1, 1)
-- !query 5 schema
struct<>
-- !query 5 output
org.apache.spark.sql.AnalysisException
error: table-valued function range with alternatives:
(end: Number)
(start: Number, end: Number)
(start: Number, end: Number, step: Number)
(start: Number, end: Number, step: Number, numPartitions: Integer)
cannot be applied to: (Integer, UTF8String); line 1 pos 14
(end: long)
(start: long, end: long)
(start: long, end: long, step: long)
(start: long, end: long, step: long, numPartitions: integer)
cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14