Skip to content

Commit eec3660

Browse files
Davies Liu (davies)
authored and committed
[SPARK-12258] [SQL] passing null into ScalaUDF (follow-up)
This is a follow-up PR for #10259.
Author: Davies Liu <[email protected]>
Closes #10266 from davies/null_udf2.
(cherry picked from commit c119a34)
Signed-off-by: Davies Liu <[email protected]>
1 parent 250249e commit eec3660

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1029,24 +1029,27 @@ case class ScalaUDF(
10291029
// such as IntegerType, its javaType is `int` and the returned type of user-defined
10301030
// function is Object. Trying to convert an Object to `int` will cause casting exception.
10311031
val evalCode = evals.map(_.code).mkString
1032-
val funcArguments = converterTerms.zipWithIndex.map {
1033-
case (converter, i) =>
1034-
val eval = evals(i)
1035-
val dt = children(i).dataType
1036-
s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})"
1037-
}.mkString(",")
1038-
val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
1039-
s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
1040-
s".apply($funcTerm.apply($funcArguments));"
1032+
val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
1033+
val eval = evals(i)
1034+
val argTerm = ctx.freshName("arg")
1035+
val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
1036+
(convert, argTerm)
1037+
}.unzip
10411038

1042-
evalCode + s"""
1043-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1044-
Boolean ${ev.isNull};
1039+
val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
1040+
s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
1041+
s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
10451042

1043+
s"""
1044+
$evalCode
1045+
${converters.mkString("\n")}
10461046
$callFunc
10471047

1048-
${ev.value} = $resultTerm;
1049-
${ev.isNull} = $resultTerm == null;
1048+
boolean ${ev.isNull} = $resultTerm == null;
1049+
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1050+
if (!${ev.isNull}) {
1051+
${ev.value} = $resultTerm;
1052+
}
10501053
"""
10511054
}
10521055

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1137,9 +1137,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
11371137

11381138
// passing null into the UDF that could handle it
11391139
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
1140-
(i: java.lang.Integer) => if (i == null) -10 else i * 2
1140+
(i: java.lang.Integer) => if (i == null) -10 else null
11411141
}
1142-
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
1142+
checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
1143+
1144+
sqlContext.udf.register("boxedUDF",
1145+
(i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
1146+
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
11431147

11441148
val primitiveUDF = udf((i: Int) => i * 2)
11451149
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)

0 commit comments

Comments
 (0)