diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cd7aeb7cd4ac..ba6764444bdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -344,10 +344,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       new ResolveHints.RemoveAllHints),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
-    Batch("UpdateNullability", Once,
+    Batch("ScalaUDF Null Handling", fixedPoint,
+      // `HandleNullInputsForUDF` may wrap the `ScalaUDF` with `If` expression to return null for
+      // null inputs, so the result can be null even if `ScalaUDF#nullable` is false. We need to
+      // run `UpdateAttributeNullability` to update nullability of the UDF output attribute in
+      // downstream operators. After updating attribute nullability, `ScalaUDF`s in downstream
+      // operators may need null handling as well, so we should run these two rules repeatedly.
+      HandleNullInputsForUDF,
       UpdateAttributeNullability),
     Batch("UDF", Once,
-      HandleNullInputsForUDF,
       ResolveEncodersInUDF),
     Batch("Subquery", Once,
       UpdateOuterReferences),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a65fbef1a373..62856a96f7ee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1796,4 +1796,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     assert(refs.head.resolved)
     assert(refs.head.isStreaming)
   }
+
+  test("SPARK-47927: ScalaUDF output nullability") {
+    val udf = ScalaUDF(
+      function = (i: Int) => i + 1,
+      dataType = IntegerType,
+      children = $"a" :: Nil,
+      nullable = false,
+      inputEncoders = Seq(Some(ExpressionEncoder[Int]().resolveAndBind())))
+    val plan = testRelation.select(udf.as("u")).select($"u").analyze
+    assert(plan.output.head.nullable)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 32ad5a94984b..7e940252430f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -1211,4 +1211,11 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("SPARK-47927: ScalaUDF null handling") {
+    val f = udf[Int, Int](_ + 1)
+    val df = Seq(Some(1), None).toDF("c")
+      .select(f($"c").as("f"), f($"f"))
+    checkAnswer(df, Seq(Row(2, 3), Row(null, null)))
+  }
 }