Commit 33768f6

Authored by Emil Ejbyfeldt, committed by cloud-fan
[SPARK-47927][SQL] Fix nullability attribute in UDF decoder
This PR fixes a correctness issue by moving the batch that resolves UDF decoders to after the `UpdateNullability` batch, so the decoder is now derived from the updated attributes. The issue has likely existed since #28645, when support for case class arguments in UDFs was added, and should therefore be present in all currently supported versions.

Currently the following code

```
scala> val ds1 = Seq(1).toDS()
     | val ds2 = Seq[Int]().toDS()
     | val f = udf[Tuple1[Option[Int]], Tuple1[Option[Int]]](identity)
     | ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(f(struct(ds2("value")))).collect()
val ds1: org.apache.spark.sql.Dataset[Int] = [value: int]
val ds2: org.apache.spark.sql.Dataset[Int] = [value: int]
val f: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$2481/0x00007f7f50961f086b1a2c9f,StructType(StructField(_1,IntegerType,true)),List(Some(class[_1[0]: int])),Some(class[_1[0]: int]),None,true,true)
val res0: Array[org.apache.spark.sql.Row] = Array([[0]])
```

results in a row containing `0`, which is incorrect: the value should be `null`. Removing the UDF call

```
scala> ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(struct(ds2("value"))).collect()
val res1: Array[org.apache.spark.sql.Row] = Array([[null]])
```

gives the correct value.

Does this PR introduce any user-facing change? Yes, it fixes a correctness issue when using Scala UDFs.

How was this patch tested? Existing and new unit tests.

Was this patch authored or co-authored using generative AI tooling? No.

Closes #46156 from eejbyfeldt/SPARK-47927.

Authored-by: Emil Ejbyfeldt <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 8b8ea60)
Signed-off-by: Wenchen Fan <[email protected]>
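To see why a decoder built against stale attributes yields `0` rather than `null`, the following plain-Scala sketch (illustrative only; `NullabilityDemo` and its methods are hypothetical names, not Spark's actual decoder code) models the two decode paths. A primitive `Int` has no null representation, so a decoder that still believes the field is non-nullable substitutes the JVM default when the left outer join produces a NULL:

```scala
// Minimal illustration (assumption: not Spark's decoder implementation) of
// why decoding a NULL slot through a non-nullable path produces 0.
object NullabilityDemo {
  // A column value as the decoder sees it: None models the SQL NULL that the
  // left outer join introduces on the unmatched side.
  type Slot = Option[Int]

  // Decoder derived *before* nullability was updated: it reads a primitive
  // Int, so NULL collapses to the default value 0.
  def decodeNonNullable(s: Slot): Int = s.getOrElse(0)

  // Decoder derived from the *updated* (nullable) attributes: NULL survives.
  def decodeNullable(s: Slot): Option[Int] = s

  def main(args: Array[String]): Unit = {
    val nullSlot: Slot = None
    println(decodeNonNullable(nullSlot)) // prints 0 (the buggy result)
    println(decodeNullable(nullSlot))    // prints None (the correct result)
  }
}
```

Reordering the analyzer batches so `ResolveEncodersInUDF` runs after `UpdateAttributeNullability` makes Spark take the second path.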
Parent: 7a573b9

2 files changed: +13 −2 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (2 additions, 2 deletions)

```
@@ -338,11 +338,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       new ResolveHints.RemoveAllHints),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
+    Batch("UpdateNullability", Once,
+      UpdateAttributeNullability),
     Batch("UDF", Once,
       HandleNullInputsForUDF,
       ResolveEncodersInUDF),
-    Batch("UpdateNullability", Once,
-      UpdateAttributeNullability),
     Batch("Subquery", Once,
       UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
```

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala (11 additions, 0 deletions)

```
@@ -1067,4 +1067,15 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .lookupFunctionInfo(FunctionIdentifier("dummyUDF"))
     assert(expressionInfo.getClassName.contains("org.apache.spark.sql.UDFRegistration$$Lambda"))
   }
+
+  test("SPARK-47927: Correctly pass null values derived from join to UDF") {
+    val f = udf[Tuple1[Option[Int]], Tuple1[Option[Int]]](identity)
+    val ds1 = Seq(1).toDS()
+    val ds2 = Seq[Int]().toDS()
+
+    checkAnswer(
+      ds1.join(ds2, ds1("value") === ds2("value"), "left_outer")
+        .select(f(struct(ds2("value").as("_1")))),
+      Row(Row(null)))
+  }
 }
```
