Skip to content
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4089,32 +4089,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
@transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
val hs = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(valueNaN: Any) => {})
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
arrayBuffer += value
hs.add(value)
},
(valueNaN: Any) => arrayBuffer += valueNaN)
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
notFoundNullElement = false
hs.addNull()
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
withArray2NaNCheckFunc(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (notFoundNullElement) {
if (!hs.containsNull()) {
arrayBuffer += null
notFoundNullElement = false
hs.addNull()
}
} else {
val elem = array1.get(i, elementType)
if (!hs.contains(elem)) {
arrayBuffer += elem
hs.add(elem)
}
withArray1NaNCheckFunc(elem)
}
i += 1
}
Expand Down Expand Up @@ -4183,10 +4189,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
val ptName = CodeGenerator.primitiveTypeName(jt)

nullSafeCodeGen(ctx, ev, (array1, array2) => {
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val openHashSet = classOf[OpenHashSet[_]].getName
val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
Expand All @@ -4197,7 +4202,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array2.isNullAt($i)) {
| $notFoundNullElement = false;
| if (!$hashSet.containsNull()) {
| $hashSet.addNull();
| }
|} else {
| $body
|}
Expand All @@ -4215,18 +4222,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
}

val writeArray2ToHashSet = withArray2NullCheck(
s"""
|$jt $value = ${genGetValue(array2, i)};
|$hashSet.add$hsPostFix($hsValueCast$value);
""".stripMargin)
s"$jt $value = ${genGetValue(array2, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);",
(valueNaN: Any) => ""))

def withArray1NullAssignment(body: String) =
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array1.isNullAt($i)) {
| if ($notFoundNullElement) {
| if (!$hashSet.containsNull()) {
| $hashSet.addNull();
| $nullElementIndex = $size;
| $notFoundNullElement = false;
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
Expand All @@ -4238,22 +4245,29 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
body
}

val processArray1 = withArray1NullAssignment(
val body =
s"""
|$jt $value = ${genGetValue(array1, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)
""".stripMargin

val processArray1 = withArray1NullAssignment(
s"$jt $value = ${genGetValue(array1, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
(valueNaN: String) =>
s"""
|$size++;
|$builder.$$plus$$eq($valueNaN);
""".stripMargin))

// Only need to track null element index when array1's element is nullable.
val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $notFoundNullElement = true;
|int $nullElementIndex = -1;
""".stripMargin
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}

// Regression test for SPARK-36753. Every case builds an ArrayExcept whose left and right
// arrays both contain NaN and asserts that NaN does not appear in the evaluated result.
test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") {
// Double elements, no nulls: left is [NaN, 1d], right is [NaN]; only 1d is expected back.
checkEvaluation(ArrayExcept(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))),
Seq(1d))
// Double elements with nulls on both sides (null appears twice on the left): the expected
// result still contains only 1d, i.e. neither null nor NaN is kept.
checkEvaluation(ArrayExcept(
Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType))),
Seq(1d))
// Same two scenarios repeated for Float elements.
checkEvaluation(ArrayExcept(
Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN))),
Seq(1f))
checkEvaluation(ArrayExcept(
Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
Literal.create(Seq(Float.NaN, null), ArrayType(FloatType))),
Seq(1f))
}

test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayDistinct(
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
Expand Down