Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1b39da5
[SPARK-36742][SQL] ArrayDistinct handle duplicated Double.NaN and Flo…
AngersZhuuuu Sep 14, 2021
0ac9924
Update collectionOperations.scala
AngersZhuuuu Sep 14, 2021
4189d71
[SPARK-36754][SQL] ArrayIntersect handle duplicated Double.NaN and Fl…
AngersZhuuuu Sep 14, 2021
63763df
Update CollectionExpressionsSuite.scala
AngersZhuuuu Sep 14, 2021
55f230d
Merge branch 'master' into SPARK-36741
AngersZhuuuu Sep 15, 2021
7bc972f
Update collectionOperations.scala
AngersZhuuuu Sep 15, 2021
f5c5452
Merge branch 'master' into SPARK-36741
AngersZhuuuu Sep 15, 2021
2fe60cd
Merge branch 'master' into SPARK-36754
AngersZhuuuu Sep 15, 2021
64afef9
Update collectionOperations.scala
AngersZhuuuu Sep 15, 2021
2478eb4
refactor
AngersZhuuuu Sep 16, 2021
202cf4e
follow comment
AngersZhuuuu Sep 16, 2021
5546a55
Update SQLOpenHashSet.scala
AngersZhuuuu Sep 16, 2021
72870f6
update
AngersZhuuuu Sep 16, 2021
389c9fd
follow comment
AngersZhuuuu Sep 16, 2021
d8e80fc
Update collectionOperations.scala
AngersZhuuuu Sep 17, 2021
d0a914c
update
AngersZhuuuu Sep 17, 2021
6194de6
Update SQLOpenHashSet.scala
AngersZhuuuu Sep 17, 2021
df380c3
Merge branch 'SPARK-36741' into SPARK-36754
AngersZhuuuu Sep 17, 2021
217934b
update
AngersZhuuuu Sep 17, 2021
f97b838
Update collectionOperations.scala
AngersZhuuuu Sep 17, 2021
ac52229
Update collectionOperations.scala
AngersZhuuuu Sep 17, 2021
a9e6205
Update collectionOperations.scala
AngersZhuuuu Sep 17, 2021
85f9d9d
Merge branch 'master' into SPARK-36754
AngersZhuuuu Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3843,33 +3843,42 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
if (array1.numElements() != 0 && array2.numElements() != 0) {
val hs = new OpenHashSet[Any]
val hsResult = new OpenHashSet[Any]
var foundNullElement = false
val hs = new SQLOpenHashSet[Any]
val hsResult = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(valueNaN: Any) => {} )
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult,
(value: Any) =>
if (hs.contains(value) && !hsResult.contains(value)) {
arrayBuffer += value
hsResult.add(value)
},
(valueNaN: Any) =>
if (hs.containsNaN()) {
arrayBuffer += valueNaN
})
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
foundNullElement = true
hs.addNull()
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
withArray2NaNCheckFunc(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (foundNullElement) {
if (hs.containsNull() && !hsResult.containsNull()) {
arrayBuffer += null
foundNullElement = false
hsResult.addNull()
}
} else {
val elem = array1.get(i, elementType)
if (hs.contains(elem) && !hsResult.contains(elem)) {
arrayBuffer += elem
hsResult.add(elem)
}
withArray1NaNCheckFunc(elem)
}
i += 1
}
Expand Down Expand Up @@ -3944,10 +3953,9 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
val ptName = CodeGenerator.primitiveTypeName(jt)

nullSafeCodeGen(ctx, ev, (array1, array2) => {
val foundNullElement = ctx.freshName("foundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val openHashSet = classOf[OpenHashSet[_]].getName
val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val hashSetResult = ctx.freshName("hashSetResult")
Expand All @@ -3959,7 +3967,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array2.isNullAt($i)) {
| $foundNullElement = true;
| $hashSet.addNull();
|} else {
| $body
|}
Expand All @@ -3977,19 +3985,18 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
}

val writeArray2ToHashSet = withArray2NullCheck(
s"""
|$jt $value = ${genGetValue(array2, i)};
|$hashSet.add$hsPostFix($hsValueCast$value);
""".stripMargin)
s"$jt $value = ${genGetValue(array2, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);", (valueNaN: String) => ""))

def withArray1NullAssignment(body: String) =
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array1.isNullAt($i)) {
| if ($foundNullElement) {
| if ($hashSet.containsNull() && !$hashSetResult.containsNull()) {
| $nullElementIndex = $size;
| $foundNullElement = false;
| $hashSetResult.addNull();
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
Expand All @@ -4008,9 +4015,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
body
}

val processArray1 = withArray1NullAssignment(
val body =
s"""
|$jt $value = ${genGetValue(array1, i)};
|if ($hashSet.contains($hsValueCast$value) &&
| !$hashSetResult.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
Expand All @@ -4019,12 +4025,22 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
| $hashSetResult.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)
""".stripMargin

val processArray1 = withArray1NullAssignment(
s"$jt $value = ${genGetValue(array1, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSetResult, body,
(valueNaN: Any) =>
s"""
|if ($hashSet.containsNaN()) {
| ++$size;
| $builder.$$plus$$eq($valueNaN);
|}
""".stripMargin))

// Only need to track null element index when result array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $foundNullElement = false;
|int $nullElementIndex = -1;
""".stripMargin
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}

test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayIntersect(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
Seq(Double.NaN, 1d))
checkEvaluation(ArrayIntersect(
Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
Literal.create(Seq(null, Double.NaN, null), ArrayType(DoubleType))),
Seq(null, Double.NaN))
checkEvaluation(ArrayIntersect(
Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN, 1f, 2f))),
Seq(Float.NaN, 1f))
checkEvaluation(ArrayIntersect(
Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
Literal.create(Seq(null, Float.NaN, null), ArrayType(FloatType))),
Seq(null, Float.NaN))
}

test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayDistinct(
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
Expand Down