Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
Ngone51 committed Jun 18, 2020
commit 1c82558871f0c1f4eff0b08e809df5e862e0892d
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ case class ScalaUDF(
private def scalaConverter(i: Int, dataType: DataType): Any => Any = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep the name unchanged? If we keep using createToScalaConverter, many diffs can be avoided.

if (inputEncoders.isEmpty || // for untyped Scala UDF
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also for Java UDF

inputEncoders(i).isEmpty || // for types aren't supported by encoder, e.g. Any
isPrimitive(dataType) ||
dataType.isInstanceOf[UserDefinedType[_]]) {
inputPrimitives(i) || // inputPrimitives is not empty when inputEncoders is not empty
dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) {
createToScalaConverter(dataType)
} else {
val enc = inputEncoders(i).get
Expand Down Expand Up @@ -1065,14 +1065,32 @@ case class ScalaUDF(
val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map {
case ((eval, i), dt) =>
val argTerm = ctx.freshName("arg")
val initArg = if (isPrimitive(dt)) {
// Check `inputPrimitives` when it's not empty in order to figure out the Option
// type as non primitive type, e.g., Option[Int]. Fall back to `isPrimitive` when
// `inputPrimitives` is empty for other cases, e.g., Java UDF, untyped Scala UDF
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so untyped Scala UDF doesn't support Option?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea. We require the encoder to support Option but untyped Scala UDF can't provide the encoder.

val primitive = (inputPrimitives.isEmpty && isPrimitive(dt)) ||
(inputPrimitives.nonEmpty && inputPrimitives(i))
val initArg = if (primitive) {
val convertedTerm = ctx.freshName("conv")
s"""
|${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value};
|Object $argTerm = ${eval.isNull} ? null : $convertedTerm;
""".stripMargin
} else {
s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});"
s"""
|Object $argTerm = null;
|// handle the top level Option type specifically
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's special for top-level Option?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the top-level Option, e.g. Option[T], its internal data type is T. However, a udf always requires the external data type for its input values. So, when the ScalaUDF receives a null value of type T from the child, it needs to convert it to None instead of simply passing in the null value as it would for other nullable data types.

|if (${eval.isNull}) {
| try {
| $argTerm = $convertersTerm[$i].apply(null);
| } catch (Exception e) {
| // it's not a scala.Option type
| }
|}
|if (!($argTerm instanceof scala.Option)) {
| $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});
|}
""".stripMargin
}
(argTerm, initArg)
}.unzip
Expand Down
28 changes: 26 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,10 @@ class UDFSuite extends QueryTest with SharedSparkSession {
test("case class as generic type of Option") {
  // UDF over Option[TestData]: defined values yield key * value, None yields null.
  val multiply = udf((opt: Option[TestData]) => opt.map(d => d.key * d.value.toInt))
  // Defined Option: the UDF result should be 50 * 2 = 100.
  val presentDf = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
  checkAnswer(presentDf.select(multiply(Column("col2"))), Row(100) :: Nil)
  // Absent Option (None): the UDF result should be null.
  val absentDf = Seq(("data", None: Option[TestData])).toDF("col1", "col2")
  checkAnswer(absentDf.select(multiply(Column("col2"))), Row(null) :: Nil)
}

test("more input fields than expect for case class") {
Expand Down Expand Up @@ -652,4 +654,26 @@ class UDFSuite extends QueryTest with SharedSparkSession {
.select(struct("value", "key").as("col"))
checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil)
}

test("top level Option primitive type") {
  // Top-level Option over a primitive: Some(10) maps to 100, None maps to null.
  val timesTen = udf((maybe: Option[Int]) => maybe.map(_ * 10))
  val df = Seq(Some(10), None).toDF("col")
  checkAnswer(df.select(timesTen(Column("col"))), Row(100) :: Row(null) :: Nil)
}

test("top level Option case class") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already tested in case class as generic type of Option

  // NOTE(review): overlaps with the earlier "case class as generic type of Option"
  // test, as flagged in the inline review comment above — consider merging them.
  // Exercises a top-level Option wrapping a case class: Some yields 100, None yields null.
  val f = (i: Option[TestData]) => i.map(t => t.key * t.value.toInt)
  val myUdf = udf(f)
  val df = Seq(Some(TestData(50, "2")), None).toDF("col")
  checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Row(null) :: Nil)
}

test("array Option") {
  // Option nested inside an Array: sums key * value per element, treating None as 0.
  val sumUdf = udf { (elems: Array[Option[TestData]]) =>
    elems.map(_.fold(0)(t => t.key * t.value.toInt)).sum
  }
  // Array(Some(TestData(50, "2")), None) => 50 * 2 + 0 = 100.
  val df = Seq(Array(Some(TestData(50, "2")), None)).toDF("col")
  checkAnswer(df.select(sumUdf(Column("col"))), Row(100) :: Nil)
}
}