fix ScalaUDF output nullability
cloud-fan committed Jun 25, 2024
commit cd685ee264b9939c8f173c4506f0d7c6902c2345
@@ -344,10 +344,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
        new ResolveHints.RemoveAllHints),
      Batch("Nondeterministic", Once,
        PullOutNondeterministic),
      Batch("ScalaUDF", Once,
        HandleNullInputsForUDF),
Contributor Author

This rule wraps the ScalaUDF in null handling, so even if the ScalaUDF itself is declared with nullable = false, the resulting if-else expression can still return null. UpdateAttributeNullability therefore needs to run after it.
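
To illustrate the point, here is a minimal sketch using a toy expression model. The types and the `handleNullInputs` helper are hypothetical stand-ins written for this example, not Spark's Catalyst classes or the actual rule; the sketch only assumes that the rewrite guards the UDF call with a null check on its input.

```scala
object NullabilitySketch {
  // Toy expression model; illustrative only, NOT Spark's Catalyst classes.
  sealed trait Expr { def nullable: Boolean }
  final case class Attr(name: String, nullable: Boolean) extends Expr
  final case class NullLit() extends Expr { val nullable = true }
  final case class IsNull(child: Expr) extends Expr { val nullable = false }
  final case class Udf(children: Seq[Expr], nullable: Boolean) extends Expr
  final case class If(cond: Expr, thenE: Expr, elseE: Expr) extends Expr {
    // An if-else can yield null whenever either branch can.
    def nullable: Boolean = thenE.nullable || elseE.nullable
  }

  // Hypothetical stand-in for the null-handling rewrite: guard the UDF call
  // so that a null input short-circuits to a null result.
  def handleNullInputs(udf: Udf): Expr =
    udf.children.headOption match {
      case Some(in) => If(IsNull(in), NullLit(), udf)
      case None     => udf
    }

  // The UDF itself claims to never return null...
  val udf: Udf = Udf(Seq(Attr("a", nullable = true)), nullable = false)
  // ...but the guarded expression that replaces it can.
  val wrapped: Expr = handleNullInputs(udf)
}
```

In this model `udf.nullable` is false while `wrapped.nullable` is true, so any attribute nullability computed before the rewrite would be stale afterwards.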

Contributor

@eejbyfeldt eejbyfeldt Jun 25, 2024

But don't we also need to run it before? Otherwise, how can we correctly handle null inputs to a UDF if the inputs do not yet have the correct nullability?

It seems to me that one should be able to write a test case involving an outer join that shows the proposed code does not handle such nullable input correctly.
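
A rough sketch of the ordering concern, again with a toy model rather than Spark's classes. It assumes (this is an assumption of the sketch, not something confirmed in this thread) that the null-handling rewrite only guards inputs that already report themselves as nullable, and that an attribute coming from the outer side of a join is only marked nullable once the nullability update has run. All names here are hypothetical.

```scala
object OrderingSketch {
  // Toy model; illustrative only, NOT Spark's Catalyst classes.
  final case class Attr(name: String, nullable: Boolean)
  sealed trait Expr
  final case class Udf(input: Attr) extends Expr
  final case class NullGuarded(udf: Udf) extends Expr

  // Hypothetical stand-in for the nullability update: after an outer join,
  // the attribute is marked as nullable.
  def updateNullability(a: Attr): Attr = a.copy(nullable = true)

  // Assumption: a guard is only added when the input reports nullable = true.
  def handleNullInputs(u: Udf): Expr =
    if (u.input.nullable) NullGuarded(u) else u

  // Attribute from the outer side of a join, before its nullability is fixed up.
  val stale: Attr = Attr("b", nullable = false)

  // Null handling first: the input still looks non-nullable, so no guard is added.
  val guardFirst: Expr = handleNullInputs(Udf(stale))

  // Nullability update first: the input is seen as nullable and gets guarded.
  val updateFirst: Expr = handleNullInputs(Udf(updateNullability(stale)))
}
```

Under these assumptions, only the second ordering produces a guarded UDF, which is the situation an outer-join test case would try to expose.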

      Batch("UpdateNullability", Once,
        UpdateAttributeNullability),
      Batch("UDF", Once,
        HandleNullInputsForUDF,
        ResolveEncodersInUDF),
      Batch("Subquery", Once,
        UpdateOuterReferences),
@@ -1796,4 +1796,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
    assert(refs.head.resolved)
    assert(refs.head.isStreaming)
  }

  test("SPARK-47927: ScalaUDF output nullability") {
    val udf = ScalaUDF(
      function = (i: Int) => i + 1,
      dataType = IntegerType,
      children = $"a" :: Nil,
      nullable = false,
      inputEncoders = Seq(Some(ExpressionEncoder[Int]().resolveAndBind())))
    val plan = testRelation.select(udf.as("u")).select($"u").analyze
    assert(plan.output.head.nullable)
  }
}