Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Modifier

import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -455,8 +457,17 @@ object FunctionRegistry {
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

// For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
// constructor and contains non-parameter `child` and should not be used as function builder.
val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
val all = tag.runtimeClass.getConstructors
val maxNumArgs = all.map(_.getParameterCount).max
all.filterNot(_.getParameterCount == maxNumArgs)
} else {
tag.runtimeClass.getConstructors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason why we originally called getDeclaredConstructor instead of getConstructor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getDeclaredConstructor will find a specific constructor matching the given parameter types. Now I have some special logic about choosing the constructor, so I call getConstructors to get all the constructors.

}
// See if we can find a constructor that accepts Seq[Expression]
val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is an apply method that accepts Seq[Expression], use that one.
Expand All @@ -470,11 +481,8 @@ object FunctionRegistry {
} else {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match {
case Success(e) =>
e
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2619,4 +2619,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
new URL(jarFromInvalidFs)
}
}

test("RuntimeReplaceable functions should not take extra parameters") {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}
}