Skip to content

Commit 4636fe3

Browse files
author
Davies Liu
committed
passing null into ScalaUDF
1 parent 442a771 commit 4636fe3

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1029,8 +1029,11 @@ case class ScalaUDF(
10291029
// such as IntegerType, its javaType is `int` and the returned type of user-defined
10301030
// function is Object. Trying to convert an Object to `int` will cause casting exception.
10311031
val evalCode = evals.map(_.code).mkString
1032-
val funcArguments = converterTerms.zip(evals).map {
1033-
case (converter, eval) => s"$converter.apply(${eval.value})"
1032+
val funcArguments = converterTerms.zipWithIndex.map {
1033+
case (converter, i) =>
1034+
val eval = evals(i)
1035+
val dt = children(i).dataType
1036+
s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})"
10341037
}.mkString(",")
10351038
val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
10361039
s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1138,14 +1138,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
11381138
}
11391139

11401140
test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
1141-
val df = Seq(
1141+
val df = sparkContext.parallelize(Seq(
11421142
new java.lang.Integer(22) -> "John",
1143-
null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name")
1143+
null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name")
11441144

1145+
// passing null into the UDF that could handle it
11451146
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
1146-
(i: java.lang.Integer) => if (i == null) null else i * 2
1147+
(i: java.lang.Integer) => if (i == null) -10 else i * 2
11471148
}
1148-
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil)
1149+
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
11491150

11501151
val primitiveUDF = udf((i: Int) => i * 2)
11511152
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)

0 commit comments

Comments (0)