Skip to content
Next Next commit
[SPARK-36792][SQL] InSet should handle NaN
  • Loading branch information
AngersZhuuuu committed Sep 17, 2021
commit a74a3104e66dcfbfe7561e524a2b75555addfbba
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,11 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}

// True when the literal set `hset` contains a null element. Declared lazy and @transient,
// so it is computed on first use and is not part of the serialized state.
@transient private[this] lazy val hasNull: Boolean = hset.contains(null)
// Predicate that tells whether a runtime value is a floating-point NaN, selected once
// from the child's data type; for non-floating-point types it is always false.
//
// Declared `@transient lazy` (like `hasNull`) so that `child.dataType` is only consulted
// on first use and the closure is not carried as serialized state.
//
// The typed patterns never match `null`, so the predicate can be applied safely to
// every element of a set that also contains a null entry.
@transient private[this] lazy val isNaN: Any => Boolean = child.dataType match {
  case DoubleType =>
    (value: Any) => value match {
      case d: java.lang.Double => d.isNaN()
      case _ => false
    }
  case FloatType =>
    (value: Any) => value match {
      case f: java.lang.Float => f.isNaN()
      case _ => false
    }
  case _ =>
    (_: Any) => false
}

// The result may be null when the child expression is nullable or when the set itself
// contains a null element.
override def nullable: Boolean = child.nullable || hasNull

Expand All @@ -562,6 +567,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
protected override def nullSafeEval(value: Any): Any = {
if (set.contains(value)) {
true
} else if (isNaN(value)) {
set.exists(isNaN(_))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a `hasNaN` variable so this is not recomputed on every evaluation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the current version?

} else if (hasNull) {
null
} else {
Expand Down Expand Up @@ -593,15 +600,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
/**
 * Generates the Java code that evaluates this predicate through a lookup in `set`.
 *
 * The emitted code sets `ev.value` to true when the set contains the input value. For
 * double/float children a NaN input additionally matches when the set holds a NaN,
 * since the plain `contains` lookup does not treat NaN as equal to NaN. When the set
 * contains a null, a non-match is reported as NULL (`ev.isNull`) instead of false.
 *
 * @param ctx codegen context used to register the set as a referenced object
 * @param ev  holder of the generated `value` / `isNull` variable names
 * @return the expression code produced by `nullSafeCodeGen`
 */
private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  nullSafeCodeGen(ctx, ev, c => {
    val setTerm = ctx.addReferenceObj("set", set)

    // With a null in the set, "not found" means "unknown": isNull becomes !value.
    val setIsNull = if (hasNull) {
      s"${ev.isNull} = !${ev.value};"
    } else {
      ""
    }

    // Java snippet builder for the NaN test; only defined for floating-point children.
    val nanCheck: Option[String => String] = child.dataType match {
      case DoubleType => Some((v: String) => s"java.lang.Double.isNaN((double)$v)")
      case FloatType => Some((v: String) => s"java.lang.Float.isNaN((float)$v)")
      case _ => None
    }

    // Whether the set holds a NaN is decided once, here at code-generation time, so the
    // generated Java never has to walk the Scala Set itself. Typed patterns skip nulls.
    lazy val setHasNaN: Boolean = set.exists {
      case d: java.lang.Double => d.isNaN()
      case f: java.lang.Float => f.isNaN()
      case _ => false
    }

    nanCheck match {
      case Some(isNaNOf) if setHasNaN =>
        s"""
           |${ev.value} = $setTerm.contains($c) || ${isNaNOf(c)};
           |$setIsNull
         """.stripMargin
      case _ =>
        s"""
           |${ev.value} = $setTerm.contains($c);
           |$setIsNull
         """.stripMargin
    }
  })
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,4 +644,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
checkExpr(GreaterThan, 0.0, -0.0, false)
}

test("Inset NaN") {
  // Shared literals used to build the IN lists below.
  val nan = Literal(Double.NaN)
  val two = Literal(2d)
  val three = Literal(3d)
  def nullDouble: Literal = Literal.create(null, DoubleType)

  // Each entry pairs an IN expression with the result expected from both In and InSet.
  val cases: Seq[(In, Any)] = Seq(
    // NaN probe against a list holding NaN matches.
    (In(nan, Seq(nan, two)), true),
    // A null probe always yields null, with or without a null in the list.
    (In(nullDouble, Seq(nan, two, nullDouble)), null),
    (In(nullDouble, Seq(nan, two)), null),
    // A non-matching probe is false without a null in the list, null with one.
    (In(three, Seq(nan, two)), false),
    (In(three, Seq(nan, two, nullDouble)), null),
    // A NaN match wins even when the list also holds a null.
    (In(nan, Seq(nan, two, nullDouble)), true)
  )

  cases.foreach { case (expr, expected) =>
    checkInAndInSet(expr, expected)
  }
}
}