Skip to content

Commit a7cbe69

Browse files
AngersZhuuuu authored and cloud-fan committed
[SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request? For the query ``` select array_except(array(cast('nan' as double), 1d), array(cast('nan' as double))) ``` the result is [NaN, 1d], but it should be [1d]. The issue occurs because `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` here either. This PR fixes it, building on #33955. ### Why are the changes needed? To fix a bug. ### Does this PR introduce _any_ user-facing change? Yes. `ArrayExcept` now treats `NaN` values as equal, so a `NaN` in the first array is removed when the second array also contains `NaN`. ### How was this patch tested? Added unit tests. Closes #33994 from AngersZhuuuu/SPARK-36753. Authored-by: Angerszhuuuu <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent ec26d94 commit a7cbe69

File tree

2 files changed

+53
-25
lines changed

2 files changed

+53
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder
3939
import org.apache.spark.unsafe.array.ByteArrayMethods
4040
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
4141
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
42-
import org.apache.spark.util.collection.OpenHashSet
4342

4443
/**
4544
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
@@ -4109,32 +4108,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
41094108
@transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
41104109
if (TypeUtils.typeWithProperEquals(elementType)) {
41114110
(array1, array2) =>
4112-
val hs = new OpenHashSet[Any]
4113-
var notFoundNullElement = true
4111+
val hs = new SQLOpenHashSet[Any]
4112+
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
4113+
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
4114+
(value: Any) => hs.add(value),
4115+
(valueNaN: Any) => {})
4116+
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
4117+
(value: Any) =>
4118+
if (!hs.contains(value)) {
4119+
arrayBuffer += value
4120+
hs.add(value)
4121+
},
4122+
(valueNaN: Any) => arrayBuffer += valueNaN)
41144123
var i = 0
41154124
while (i < array2.numElements()) {
41164125
if (array2.isNullAt(i)) {
4117-
notFoundNullElement = false
4126+
hs.addNull()
41184127
} else {
41194128
val elem = array2.get(i, elementType)
4120-
hs.add(elem)
4129+
withArray2NaNCheckFunc(elem)
41214130
}
41224131
i += 1
41234132
}
4124-
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
41254133
i = 0
41264134
while (i < array1.numElements()) {
41274135
if (array1.isNullAt(i)) {
4128-
if (notFoundNullElement) {
4136+
if (!hs.containsNull()) {
41294137
arrayBuffer += null
4130-
notFoundNullElement = false
4138+
hs.addNull()
41314139
}
41324140
} else {
41334141
val elem = array1.get(i, elementType)
4134-
if (!hs.contains(elem)) {
4135-
arrayBuffer += elem
4136-
hs.add(elem)
4137-
}
4142+
withArray1NaNCheckFunc(elem)
41384143
}
41394144
i += 1
41404145
}
@@ -4203,10 +4208,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42034208
val ptName = CodeGenerator.primitiveTypeName(jt)
42044209

42054210
nullSafeCodeGen(ctx, ev, (array1, array2) => {
4206-
val notFoundNullElement = ctx.freshName("notFoundNullElement")
42074211
val nullElementIndex = ctx.freshName("nullElementIndex")
42084212
val builder = ctx.freshName("builder")
4209-
val openHashSet = classOf[OpenHashSet[_]].getName
4213+
val openHashSet = classOf[SQLOpenHashSet[_]].getName
42104214
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
42114215
val hashSet = ctx.freshName("hashSet")
42124216
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -4217,7 +4221,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42174221
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
42184222
s"""
42194223
|if ($array2.isNullAt($i)) {
4220-
| $notFoundNullElement = false;
4224+
| $hashSet.addNull();
42214225
|} else {
42224226
| $body
42234227
|}
@@ -4235,18 +4239,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42354239
}
42364240

42374241
val writeArray2ToHashSet = withArray2NullCheck(
4238-
s"""
4239-
|$jt $value = ${genGetValue(array2, i)};
4240-
|$hashSet.add$hsPostFix($hsValueCast$value);
4241-
""".stripMargin)
4242+
s"$jt $value = ${genGetValue(array2, i)};" +
4243+
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
4244+
s"$hashSet.add$hsPostFix($hsValueCast$value);",
4245+
(valueNaN: Any) => ""))
42424246

42434247
def withArray1NullAssignment(body: String) =
42444248
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
42454249
s"""
42464250
|if ($array1.isNullAt($i)) {
4247-
| if ($notFoundNullElement) {
4251+
| if (!$hashSet.containsNull()) {
4252+
| $hashSet.addNull();
42484253
| $nullElementIndex = $size;
4249-
| $notFoundNullElement = false;
42504254
| $size++;
42514255
| $builder.$$plus$$eq($nullValueHolder);
42524256
| }
@@ -4258,22 +4262,29 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42584262
body
42594263
}
42604264

4261-
val processArray1 = withArray1NullAssignment(
4265+
val body =
42624266
s"""
4263-
|$jt $value = ${genGetValue(array1, i)};
42644267
|if (!$hashSet.contains($hsValueCast$value)) {
42654268
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
42664269
| break;
42674270
| }
42684271
| $hashSet.add$hsPostFix($hsValueCast$value);
42694272
| $builder.$$plus$$eq($value);
42704273
|}
4271-
""".stripMargin)
4274+
""".stripMargin
4275+
4276+
val processArray1 = withArray1NullAssignment(
4277+
s"$jt $value = ${genGetValue(array1, i)};" +
4278+
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
4279+
(valueNaN: String) =>
4280+
s"""
4281+
|$size++;
4282+
|$builder.$$plus$$eq($valueNaN);
4283+
""".stripMargin))
42724284

42734285
// Only need to track null element index when array1's element is nullable.
42744286
val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
42754287
s"""
4276-
|boolean $notFoundNullElement = true;
42774288
|int $nullElementIndex = -1;
42784289
""".stripMargin
42794290
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
23272327
Seq(Float.NaN, null, 1f))
23282328
}
23292329

2330+
test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") {
2331+
checkEvaluation(ArrayExcept(
2332+
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))),
2333+
Seq(1d))
2334+
checkEvaluation(ArrayExcept(
2335+
Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
2336+
Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType))),
2337+
Seq(1d))
2338+
checkEvaluation(ArrayExcept(
2339+
Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN))),
2340+
Seq(1f))
2341+
checkEvaluation(ArrayExcept(
2342+
Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
2343+
Literal.create(Seq(Float.NaN, null), ArrayType(FloatType))),
2344+
Seq(1f))
2345+
}
2346+
23302347
test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
23312348
checkEvaluation(ArrayIntersect(
23322349
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),

0 commit comments

Comments
 (0)