Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
simplify CaseWhen when one clause is null and another is boolean
  • Loading branch information
wangyum committed Dec 23, 2020
commit c78882a1dd06c90c5971b6ac86a0f12561b7ca4f
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,19 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
// IF(cond, false, null)  ==>  (NOT cond) AND null, only when cond itself cannot be null.
case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
// IF(cond, true, null)  ==>  cond OR null, only when cond itself cannot be null.
case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)

// The four cases below handle a CaseWhen with exactly one branch where one side is a null
// literal and the other is a boolean literal. Each is rewritten to an And/Or with the null
// literal. Every case is guarded by `!cond.nullable`: a null condition makes CASE fall through
// to the ELSE value, whereas And/Or with a null operand could yield a different result.

// CASE WHEN cond THEN null ELSE false END  ==>  cond AND null
case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(FalseLiteral))
if !cond.nullable =>
And(cond, l)
// CASE WHEN cond THEN null ELSE true END  ==>  (NOT cond) OR null
case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(TrueLiteral))
if !cond.nullable =>
Or(Not(cond), l)
// CASE WHEN cond THEN false [ELSE null] END  ==>  (NOT cond) AND null
// The pattern accepts either an explicit boolean-typed null ELSE or no ELSE at all; in the
// latter case a boolean-typed null literal is substituted for the missing ELSE.
case CaseWhen(Seq((cond, FalseLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None))
if !cond.nullable =>
And(Not(cond), elseOpt.getOrElse(Literal(null, BooleanType)))
// CASE WHEN cond THEN true [ELSE null] END  ==>  cond OR null
// Same ELSE handling as the previous case.
case CaseWhen(Seq((cond, TrueLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None))
if !cond.nullable =>
Or(cond, elseOpt.getOrElse(Literal(null, BooleanType)))

case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,52 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
}

test("SPARK-33884: simplify CaseWhen when one clause is null and another is boolean") {
  val nullBoolean = Literal(null, BooleanType)
  val booleanLiterals = Seq(TrueLiteral, FalseLiteral)

  // Expected rewrites when the single branch condition is the IsNull predicate.
  val aIsNull = IsNull('a)
  val aIsNotNull = IsNotNull('a)
  // WHEN p THEN null ELSE <boolean>
  assertEquivalent(
    CaseWhen(Seq((aIsNull, nullBoolean)), FalseLiteral),
    And(aIsNull, nullBoolean))
  assertEquivalent(
    CaseWhen(Seq((aIsNull, nullBoolean)), TrueLiteral),
    Or(aIsNotNull, nullBoolean))
  // WHEN p THEN false, with an explicit null ELSE and with no ELSE
  assertEquivalent(
    CaseWhen(Seq((aIsNull, FalseLiteral)), nullBoolean),
    And(aIsNotNull, nullBoolean))
  assertEquivalent(
    CaseWhen(Seq((aIsNull, FalseLiteral)), None),
    And(aIsNotNull, nullBoolean))
  // WHEN p THEN true, with an explicit null ELSE and with no ELSE
  assertEquivalent(
    CaseWhen(Seq((aIsNull, TrueLiteral)), nullBoolean),
    Or(aIsNull, nullBoolean))
  assertEquivalent(
    CaseWhen(Seq((aIsNull, TrueLiteral)), None),
    Or(aIsNull, nullBoolean))

  // With a nullable predicate the plan must come out of the optimizer unchanged.
  val nullablePredicate = GreaterThan('a, 42)
  for (b <- booleanLiterals) {
    val nullThenElseB = CaseWhen(Seq((nullablePredicate, nullBoolean)), b)
    assertEquivalent(nullThenElseB, nullThenElseB)

    val bThenElseNull = CaseWhen(Seq((nullablePredicate, b)), nullBoolean)
    assertEquivalent(bThenElseNull, bThenElseNull)

    val bThenNoElse = CaseWhen(Seq((nullablePredicate, b)), None)
    assertEquivalent(bThenNoElse, bThenNoElse)
  }

  // The CaseWhen form and its And/Or counterpart must also evaluate to the same value.
  for (b <- booleanLiterals) {
    val bAndNull = And(b, nullBoolean).eval(EmptyRow)
    val notBOrNull = Or(Not(b), nullBoolean).eval(EmptyRow)
    val notBAndNull = And(Not(b), nullBoolean).eval(EmptyRow)
    val bOrNull = Or(b, nullBoolean).eval(EmptyRow)

    checkEvaluation(CaseWhen(Seq((b, nullBoolean)), FalseLiteral), bAndNull)
    checkEvaluation(CaseWhen(Seq((b, nullBoolean)), TrueLiteral), notBOrNull)
    checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), nullBoolean), notBAndNull)
    checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), None), notBAndNull)
    checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), nullBoolean), bOrNull)
    checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), None), bOrNull)
  }

  // Expressions whose CaseWhen condition is nullable must keep evaluating to the same value.
  assert((Factorial(5) > 100L).nullable)
  val nullableCondition = Factorial(5) > 100L
  for (b <- booleanLiterals) {
    val nullThenElseB = CaseWhen(Seq((nullableCondition, nullBoolean)), b)
    checkEvaluation(nullThenElseB, nullThenElseB.eval(EmptyRow))

    val bThenElseNull = CaseWhen(Seq((nullableCondition, b)), nullBoolean)
    checkEvaluation(bThenElseNull, bThenElseNull.eval(EmptyRow))

    val bThenNoElse = CaseWhen(Seq((nullableCondition, b)), None)
    checkEvaluation(bThenNoElse, bThenNoElse.eval(EmptyRow))
  }
}
}