Commit 8b8ea60

Emil Ejbyfeldt authored and cloud-fan committed
[SPARK-47927][SQL] Fix nullability attribute in UDF decoder
### What changes were proposed in this pull request?

This PR fixes a correctness issue by moving the batch that resolves UDF decoders to after the `UpdateNullability` batch. We now derive the decoder from the updated attributes, which fixes the correctness issue. I think the issue has existed since apache#28645, when support for case class arguments in UDFs was added; it should therefore be present in all currently supported versions.

### Why are the changes needed?

Currently the following code

```scala
scala> val ds1 = Seq(1).toDS()
     | val ds2 = Seq[Int]().toDS()
     | val f = udf[Tuple1[Option[Int]],Tuple1[Option[Int]]](identity)
     | ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(f(struct(ds2("value")))).collect()
val ds1: org.apache.spark.sql.Dataset[Int] = [value: int]
val ds2: org.apache.spark.sql.Dataset[Int] = [value: int]
val f: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$2481/0x00007f7f50961f086b1a2c9f,StructType(StructField(_1,IntegerType,true)),List(Some(class[_1[0]: int])),Some(class[_1[0]: int]),None,true,true)
val res0: Array[org.apache.spark.sql.Row] = Array([[0]])
```

results in a row containing `0`. This is incorrect: the value should be `null`. Removing the UDF call

```scala
scala> ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(struct(ds2("value"))).collect()
val res1: Array[org.apache.spark.sql.Row] = Array([[null]])
```

gives the correct value.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a correctness issue when using ScalaUDFs.

### How was this patch tested?

Existing and new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46156 from eejbyfeldt/SPARK-47927.

Authored-by: Emil Ejbyfeldt <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 76ce6b0 · commit 8b8ea60
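Why does the stale nullability yield `0` instead of `null`? Before this fix, the decoder was resolved against attributes that still claimed the outer-join output was non-nullable, so the generated decoder could skip the null check and read the primitive slot, which comes back as the JVM default. A minimal sketch of that symptom, assuming `spark-catalyst` is on the classpath (`InternalRow` is internal Spark API, used here purely for illustration):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// A row slot that is genuinely NULL, as the left outer join produces.
val row = InternalRow(null)

row.isNullAt(0) // true: a nullability-aware decoder checks this and emits null
row.getInt(0)   // 0: what a decoder assuming a non-nullable int reads back
```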

File tree

2 files changed (+13, −2 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -339,11 +339,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       new ResolveHints.RemoveAllHints),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
+    Batch("UpdateNullability", Once,
+      UpdateAttributeNullability),
     Batch("UDF", Once,
       HandleNullInputsForUDF,
       ResolveEncodersInUDF),
-    Batch("UpdateNullability", Once,
-      UpdateAttributeNullability),
     Batch("Subquery", Once,
       UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
```

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -1183,4 +1183,15 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
     checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
   }
+
+  test("SPARK-47927: Correctly pass null values derived from join to UDF") {
+    val f = udf[Tuple1[Option[Int]], Tuple1[Option[Int]]](identity)
+    val ds1 = Seq(1).toDS()
+    val ds2 = Seq[Int]().toDS()
+
+    checkAnswer(
+      ds1.join(ds2, ds1("value") === ds2("value"), "left_outer")
+        .select(f(struct(ds2("value").as("_1")))),
+      Row(Row(null)))
+  }
 }
```
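Note the `.as("_1")` alias in the test, presumably chosen so the struct's field name lines up with the `_1` field of the `Tuple1` encoder. To run just this test from a source checkout, an invocation along these lines should work (ScalaTest's `-z` substring filter; the exact sbt command may vary with the build setup):

```
build/sbt "sql/testOnly org.apache.spark.sql.UDFSuite -- -z SPARK-47927"
```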
