@@ -39,7 +39,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder
3939import org .apache .spark .unsafe .array .ByteArrayMethods
4040import org .apache .spark .unsafe .array .ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH
4141import org .apache .spark .unsafe .types .{ByteArray , CalendarInterval , UTF8String }
42- import org .apache .spark .util .collection .OpenHashSet
4342
4443/**
4544 * Base trait for [[BinaryExpression ]]s with two arrays of the same element type and implicit
@@ -4109,32 +4108,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
41094108 @ transient lazy val evalExcept : (ArrayData , ArrayData ) => ArrayData = {
41104109 if (TypeUtils .typeWithProperEquals(elementType)) {
41114110 (array1, array2) =>
4112- val hs = new OpenHashSet [Any ]
4113- var notFoundNullElement = true
4111+ val hs = new SQLOpenHashSet [Any ]
4112+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
4113+ val withArray2NaNCheckFunc = SQLOpenHashSet .withNaNCheckFunc(elementType, hs,
4114+ (value : Any ) => hs.add(value),
4115+ (valueNaN : Any ) => {})
4116+ val withArray1NaNCheckFunc = SQLOpenHashSet .withNaNCheckFunc(elementType, hs,
4117+ (value : Any ) =>
4118+ if (! hs.contains(value)) {
4119+ arrayBuffer += value
4120+ hs.add(value)
4121+ },
4122+ (valueNaN : Any ) => arrayBuffer += valueNaN)
41144123 var i = 0
41154124 while (i < array2.numElements()) {
41164125 if (array2.isNullAt(i)) {
4117- notFoundNullElement = false
4126+ hs.addNull()
41184127 } else {
41194128 val elem = array2.get(i, elementType)
4120- hs.add (elem)
4129+ withArray2NaNCheckFunc (elem)
41214130 }
41224131 i += 1
41234132 }
4124- val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
41254133 i = 0
41264134 while (i < array1.numElements()) {
41274135 if (array1.isNullAt(i)) {
4128- if (notFoundNullElement ) {
4136+ if (! hs.containsNull() ) {
41294137 arrayBuffer += null
4130- notFoundNullElement = false
4138+ hs.addNull()
41314139 }
41324140 } else {
41334141 val elem = array1.get(i, elementType)
4134- if (! hs.contains(elem)) {
4135- arrayBuffer += elem
4136- hs.add(elem)
4137- }
4142+ withArray1NaNCheckFunc(elem)
41384143 }
41394144 i += 1
41404145 }
@@ -4203,10 +4208,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42034208 val ptName = CodeGenerator .primitiveTypeName(jt)
42044209
42054210 nullSafeCodeGen(ctx, ev, (array1, array2) => {
4206- val notFoundNullElement = ctx.freshName(" notFoundNullElement" )
42074211 val nullElementIndex = ctx.freshName(" nullElementIndex" )
42084212 val builder = ctx.freshName(" builder" )
4209- val openHashSet = classOf [OpenHashSet [_]].getName
4213+ val openHashSet = classOf [SQLOpenHashSet [_]].getName
42104214 val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $hsTypeName() "
42114215 val hashSet = ctx.freshName(" hashSet" )
42124216 val arrayBuilder = classOf [mutable.ArrayBuilder [_]].getName
@@ -4217,7 +4221,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42174221 if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
42184222 s """
42194223 |if ( $array2.isNullAt( $i)) {
4220- | $notFoundNullElement = false ;
4224+ | $hashSet .addNull() ;
42214225 |} else {
42224226 | $body
42234227 |}
@@ -4235,18 +4239,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42354239 }
42364240
42374241 val writeArray2ToHashSet = withArray2NullCheck(
4238- s """
4239- | $jt $ value = ${genGetValue(array2, i)} ;
4240- | $hashSet.add $hsPostFix( $hsValueCast$value);
4241- """ .stripMargin )
4242+ s " $jt $value = ${genGetValue(array2, i)} ; " +
4243+ SQLOpenHashSet .withNaNCheckCode(elementType, value, hashSet,
4244+ s " $hashSet.add $hsPostFix( $hsValueCast$value); " ,
4245+ ( valueNaN : Any ) => " " ) )
42424246
42434247 def withArray1NullAssignment (body : String ) =
42444248 if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
42454249 s """
42464250 |if ( $array1.isNullAt( $i)) {
4247- | if ( $notFoundNullElement) {
4251+ | if (! $hashSet.containsNull()) {
4252+ | $hashSet.addNull();
42484253 | $nullElementIndex = $size;
4249- | $notFoundNullElement = false;
42504254 | $size++;
42514255 | $builder. $$ plus $$ eq( $nullValueHolder);
42524256 | }
@@ -4258,22 +4262,29 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
42584262 body
42594263 }
42604264
4261- val processArray1 = withArray1NullAssignment(
4265+ val body =
42624266 s """
4263- | $jt $value = ${genGetValue(array1, i)};
42644267 |if (! $hashSet.contains( $hsValueCast$value)) {
42654268 | if (++ $size > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
42664269 | break;
42674270 | }
42684271 | $hashSet.add $hsPostFix( $hsValueCast$value);
42694272 | $builder. $$ plus $$ eq( $value);
42704273 |}
4271- """ .stripMargin)
4274+ """ .stripMargin
4275+
4276+ val processArray1 = withArray1NullAssignment(
4277+ s " $jt $value = ${genGetValue(array1, i)}; " +
4278+ SQLOpenHashSet .withNaNCheckCode(elementType, value, hashSet, body,
4279+ (valueNaN : String ) =>
4280+ s """
4281+ | $size++;
4282+ | $builder. $$ plus $$ eq( $valueNaN);
4283+ """ .stripMargin))
42724284
42734285 // Only need to track null element index when array1's element is nullable.
42744286 val declareNullTrackVariables = if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
42754287 s """
4276- |boolean $notFoundNullElement = true;
42774288 |int $nullElementIndex = -1;
42784289 """ .stripMargin
42794290 } else {
0 commit comments