diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a7fc87d8b65d..2b7389112eac 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1697,7 +1697,7 @@ class SparkConnectPlanner(
         inputEncoders = udfPacket.inputEncoders.map(e => Try(ExpressionEncoder(e)).toOption),
         outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
         udfName = Option(fun.getFunctionName),
-        nullable = udf.getNullable,
+        outputNullable = udf.getNullable,
         udfDeterministic = fun.getDeterministic)
     }
   }
@@ -2071,7 +2071,7 @@ class SparkConnectPlanner(
       inputEncoders = f.inputEncoders,
       outputEncoder = f.outputEncoder,
       udfName = f.name,
-      nullable = f.nullable,
+      outputNullable = f.nullable,
       udfDeterministic = f.deterministic)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index bba3d4b1a806..c59f4cf91c39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -52,10 +52,14 @@ case class ScalaUDF(
     inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
     outputEncoder: Option[ExpressionEncoder[_]] = None,
     udfName: Option[String] = None,
-    nullable: Boolean = true,
+    outputNullable: Boolean = true,
     udfDeterministic: Boolean = true)
   extends Expression with NonSQLExpression with UserDefinedExpression {
 
+  // The Rule HandleNullInputsForUDF makes the UDF null-propagatable for primitive nullable inputs.
+  override lazy val nullable: Boolean = outputNullable || inputPrimitives.zip(children)
+    .exists { case (isPrimitive, child) => isPrimitive && child.nullable }
+
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
   // `ScalaUDF` uses `ExpressionEncoder` to convert the function result to Catalyst internal format.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a65fbef1a373..e1c774151b62 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1796,4 +1796,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     assert(refs.head.resolved)
     assert(refs.head.isStreaming)
   }
+
+  test("SPARK-47927: ScalaUDF output nullability") {
+    val udf = ScalaUDF(
+      function = (i: Int) => i + 1,
+      dataType = IntegerType,
+      children = $"a" :: Nil,
+      outputNullable = false,
+      inputEncoders = Seq(Some(ExpressionEncoder[Int]().resolveAndBind())))
+    val plan = testRelation.select(udf.as("u")).select($"u").analyze
+    assert(plan.output.head.nullable)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index a75384fb0f4e..eb298d9d34d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -108,7 +108,7 @@ private[spark] case class SparkUserDefinedFunction(
       inputEncoders,
       outputEncoder,
       udfName = name,
-      nullable = nullable,
+      outputNullable = nullable,
       udfDeterministic = deterministic)
   }
 