-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-31826][SQL] Support composed type of case class for typed Scala UDF #28645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
25c881b
4546f35
c1a2d1c
d13f1c7
bb26320
5e0f445
8576d28
1afbdf5
86035fa
21e8aaf
2527c69
26fc42b
9e986f0
7568e8c
6b384b4
1c82558
21ae72b
bdbd45b
4db6401
3e97fa5
e6bb55d
f29a62a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,8 +105,8 @@ case class ScalaUDF( | |
| private def scalaConverter(i: Int, dataType: DataType): Any => Any = { | ||
| if (inputEncoders.isEmpty || // for untyped Scala UDF | ||
|
||
| inputEncoders(i).isEmpty || // for types aren't supported by encoder, e.g. Any | ||
| isPrimitive(dataType) || | ||
| dataType.isInstanceOf[UserDefinedType[_]]) { | ||
| inputPrimitives(i) || // inputPrimitives is not empty when inputEncoders is not empty | ||
| dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) { | ||
| createToScalaConverter(dataType) | ||
| } else { | ||
| val enc = inputEncoders(i).get | ||
|
|
@@ -1065,14 +1065,32 @@ case class ScalaUDF( | |
| val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map { | ||
| case ((eval, i), dt) => | ||
| val argTerm = ctx.freshName("arg") | ||
| val initArg = if (isPrimitive(dt)) { | ||
| // Check `inputPrimitives` when it's not empty in order to figure out the Option | ||
| // type as non primitive type, e.g., Option[Int]. Fall back to `isPrimitive` when | ||
| // `inputPrimitives` is empty for other cases, e.g., Java UDF, untyped Scala UDF | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So untyped Scala UDF doesn't support `Option`?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yea. We require the encoder to support `Option`.
||
| val primitive = (inputPrimitives.isEmpty && isPrimitive(dt)) || | ||
| (inputPrimitives.nonEmpty && inputPrimitives(i)) | ||
| val initArg = if (primitive) { | ||
| val convertedTerm = ctx.freshName("conv") | ||
| s""" | ||
| |${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value}; | ||
| |Object $argTerm = ${eval.isNull} ? null : $convertedTerm; | ||
| """.stripMargin | ||
| } else { | ||
| s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});" | ||
| s""" | ||
| |Object $argTerm = null; | ||
| |// handle the top level Option type specifically | ||
|
||
| |if (${eval.isNull}) { | ||
| | try { | ||
| | $argTerm = $convertersTerm[$i].apply(null); | ||
| | } catch (Exception e) { | ||
| | // it's not a scala.Option type | ||
| | } | ||
| |} | ||
| |if (!($argTerm instanceof scala.Option)) { | ||
| | $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value}); | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| (argTerm, initArg) | ||
| }.unzip | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -621,8 +621,10 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
| test("case class as generic type of Option") { | ||
| val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt) | ||
| val myUdf = udf(f) | ||
| val df = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2") | ||
| checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil) | ||
| val df1 = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2") | ||
| checkAnswer(df1.select(myUdf(Column("col2"))), Row(100) :: Nil) | ||
| val df2 = Seq(("data", None: Option[TestData])).toDF("col1", "col2") | ||
| checkAnswer(df2.select(myUdf(Column("col2"))), Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("more input fields than expect for case class") { | ||
|
|
@@ -652,4 +654,26 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
| .select(struct("value", "key").as("col")) | ||
| checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil) | ||
| } | ||
|
|
||
| test("top level Option primitive type") { | ||
| val f = (i: Option[Int]) => i.map(_ * 10) | ||
| val myUdf = udf(f) | ||
| val df = Seq(Some(10), None).toDF("col") | ||
| checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("top level Option case class") { | ||
|
||
| val f = (i: Option[TestData]) => i.map(t => t.key * t.value.toInt) | ||
| val myUdf = udf(f) | ||
| val df = Seq(Some(TestData(50, "2")), None).toDF("col") | ||
| checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("array Option") { | ||
| val f = (i: Array[Option[TestData]]) => | ||
| i.map(_.map(t => t.key * t.value.toInt).getOrElse(0)).sum | ||
| val myUdf = udf(f) | ||
| val df = Seq(Array(Some(TestData(50, "2")), None)).toDF("col") | ||
| checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we keep the name unchanged? If we keep using
`createToScalaConverter`, much of the diff can be avoided.