Changes from 1 commit: "update"
Ngone51 committed Mar 17, 2020
commit dd902989815f2b9a5fceebfbe6b78591524e4c60
@@ -581,11 +581,12 @@ object ScalaReflection extends ScalaReflection {
    * Note that it only works for scala classes with primary constructor, and currently doesn't
    * support inner class.
    */
-  def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = {
+  // FIXME(wuyi): test on inner class/repl
+  def getConstructorParameters(cls: Class[_]): Seq[Class[_]] = {
     val m = runtimeMirror(cls.getClassLoader)
     val classSymbol = m.staticClass(cls.getName)
     val t = classSymbol.selfType
-    getConstructorParameters(t)
+    getConstructorParameters(t).map { case (_, tpe) => getClassFromType(tpe) }
   }
 
   /**
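Commentary, not part of the diff: a minimal sketch of what the changed return shape gives callers, using a hypothetical case class.

  // Sketch only; Point is a hypothetical case class.
  case class Point(x: Int, label: String)

  // Before this commit the method returned (name, Type) pairs such as
  // Seq(("x", Int), ("label", String)); after it, callers get the erased
  // runtime classes instead:
  //   ScalaReflection.getConstructorParameters(classOf[Point])
  //     // => Seq(classOf[Int], classOf[String]): Seq[Class[_]]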
@@ -2768,7 +2768,7 @@ class Analyzer(
           input
         }
       }
-      // FIXME(wuyi): why applied 2 times?
+      // assign Nil to inputCaseClass to avoid applying this rule multiple times
       udf.copy(children = newInputs, inputCaseClass = Nil)
     } else {
       udf
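Commentary, not part of the diff: a minimal sketch (assumed shape, not the actual Analyzer rule) of the fixpoint guard the new comment describes. The rewrite fires only while inputCaseClass is non-empty, so clearing it makes any later pass a no-op.

  // rewriteInputs is a hypothetical stand-in for the rule's real rewrite.
  def applyOnce(udf: ScalaUDF): ScalaUDF =
    if (udf.inputCaseClass.nonEmpty) {
      val newInputs = rewriteInputs(udf)
      udf.copy(children = newInputs, inputCaseClass = Nil)
    } else {
      udf // already rewritten; the rule is idempotent
    }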
@@ -35,6 +35,9 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType}
  *                      UDF return null if there is any null input value of these types. On the
  *                      other hand, Java UDFs can only have boxed types, thus this parameter will
  *                      always be all false.
+ * @param inputCaseClass Contains the Class[_] of each case class among the input parameters.
+ *                       If an input parameter is not a case class, the corresponding value
+ *                       is None.
  * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do
  *                   not want to perform coercion, simply use "Nil". Note that it would've been
  *                   better to use Option of Seq[DataType] so we can use "None" as the case for no
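Commentary, not part of the diff: an assumed illustration of how the new parameter would be populated (one entry per UDF input), using the TestData case class from the test suite below.

  // Hypothetical values, shape inferred from the doc above:
  //   udf((d: TestData) => d.key)
  //     // inputCaseClass = Seq(Some(classOf[TestData]))
  //   udf((i: Int, d: TestData) => i + d.key)
  //     // inputCaseClass = Seq(None, Some(classOf[TestData]))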
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
@@ -448,8 +449,20 @@ case class NewInstance(
     childrenResolved && !needOuterPointer
   }
 
+  private def argConverters(): Seq[Any => Any] = {
+    val inputTypes = ScalaReflection.expressionJavaClasses(arguments)
+    val neededTypes = ScalaReflection.getConstructorParameters(cls)
+    arguments.zip(inputTypes).zip(neededTypes).map { case ((arg, input), needed) =>
+      if (needed.isAssignableFrom(input)) {
+        identity[Any] _
+      } else {
+        CatalystTypeConverters.createToScalaConverter(arg.dataType)
+      }
+    }
+  }
+
   @transient private lazy val constructor: (Seq[AnyRef]) => Any = {
-    val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
+    val paramTypes = ScalaReflection.getConstructorParameters(cls)
     val getConstructor = (paramClazz: Seq[Class[_]]) => {
       ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
         sys.error(s"Couldn't find a valid constructor on $cls")
@@ -472,6 +485,10 @@ case class NewInstance(
 
   override def eval(input: InternalRow): Any = {
     val argValues = arguments.map(_.eval(input))
+      .zip(argConverters())
+      .map { case (arg, converter) =>
+        converter(arg)
+      }
     constructor(argValues.map(_.asInstanceOf[AnyRef]))
   }

@@ -480,6 +497,20 @@
 
     val (argCode, argString, resultIsNull) = prepareArguments(ctx)
 
+    val converterClassName = classOf[Any => Any].getName
+    val convertersTerm = ctx.addReferenceObj(
+      "converters", argConverters().toArray, s"$converterClassName[]")
+    val argTypes = ScalaReflection.getConstructorParameters(cls)
+    val convertedArgs = argTypes.map { a =>
+      ctx.addMutableState(CodeGenerator.boxedType(a.getSimpleName), "convertedArg")
+    }
+    val convertedCode = argString.split(",").zip(argTypes).zipWithIndex.map {
+      case ((arg, tpe), i) =>
+        s"${convertedArgs(i)} = " +
+          s"(${CodeGenerator.boxedType(tpe.getSimpleName)}) $convertersTerm[$i].apply($arg);"
+    }.mkString("\n")
+    val convertedArgString = convertedArgs.mkString(",")
+
     val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
 
     ev.isNull = resultIsNull
@@ -488,16 +519,17 @@
       // If there are no constructors, the `new` method will fail. In
       // this case we can try to call the apply method constructor
      // that might be defined on the companion object.
-      case 0 => s"$className$$.MODULE$$.apply($argString)"
+      case 0 => s"$className$$.MODULE$$.apply($convertedArgString)"
       case _ => outer.map { gen =>
-        s"${gen.value}.new ${cls.getSimpleName}($argString)"
+        s"${gen.value}.new ${cls.getSimpleName}($convertedArgString)"
       }.getOrElse {
-        s"new $className($argString)"
+        s"new $className($convertedArgString)"
       }
     }
 
     val code = code"""
       $argCode
+      $convertedCode
       ${outer.map(_.code).getOrElse("")}
       final $javaType ${ev.value} = ${ev.isNull} ?
         ${CodeGenerator.defaultValue(dataType)} : $constructorCall;
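Commentary, not part of the diff: a hypothetical illustration of the generated snippet's shape (illustrative names, not actual codegen output). Each raw argument is routed through the registered converter array, and the constructor call then uses the converted mutable-state names:

  // convertedArg_0 = (Integer) converters[0].apply(arg_0);
  // convertedArg_1 = (String) converters[1].apply(arg_1);
  // final TestData value = isNull ? null : new TestData(convertedArg_0, convertedArg_1);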
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala (25 additions, 6 deletions)
Expand Up @@ -33,8 +33,6 @@ import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.QueryExecutionListener

private case class Person(age: Int)

private case class FunctionResult(f1: String, f2: String)

class UDFSuite extends QueryTest with SharedSparkSession {
@@ -553,10 +551,31 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     assert(e.getMessage.contains("Invalid arguments for function cast"))
   }
 
-  test("SPARK-30127: Support input case class in typed Scala UDF") {
-    val f = (p: Person) => p.age
+  test("only one case class parameter") {
+    val f = (d: TestData) => d.key * d.value.toInt
     val myUdf = udf(f)
+    val df = Seq(("data", TestData(50, "2"))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("one case class with primitive parameter") {
+    val f = (i: Int, p: TestData) => p.key * i
+    val myUdf = udf(f)
+    val df = Seq((2, TestData(50, "data"))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("multiple case class parameters") {
+    val f = (d1: TestData, d2: TestData) => d1.key * d2.key
+    val myUdf = udf(f)
+    val df = Seq((TestData(10, "d1"), TestData(50, "d2"))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(500) :: Nil)
+  }
+
+  test("input case class parameter and return case class") {
[Review comment from a Contributor on this test] can we test nested case class as well? (A hypothetical sketch of such a test follows the diff.)
+    val f = (d1: TestData) => TestData(d1.key * 2, "copy")
     val myUdf = udf(f)
-    val df = Seq(("Jack", Person(50))).toDF("name", "age")
-    checkAnswer(df.select(myUdf(Column("age"))), Row(50) :: Nil)
+    val df = Seq(("data", TestData(50, "d2"))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(Row(100, "copy")) :: Nil)
   }
 }
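Commentary, not part of the diff: in answer to the review question, a nested case class test might look like the sketch below. NestedData is hypothetical and, like FunctionResult above, would need to be defined at the top level of the file.

  // Hypothetical top-level definition:
  // private case class NestedData(id: Int, inner: TestData)

  test("nested case class parameter") {
    val f = (n: NestedData) => n.id * n.inner.key
    val myUdf = udf(f)
    val df = Seq(("data", NestedData(2, TestData(50, "d")))).toDF("col1", "col2")
    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
  }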