-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23914][SQL] Add array_union function #21061
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
dc9d6f0
3019840
8cee6cf
2041ec4
8c2280b
b3a3132
a2c7dd1
5313680
98f8d1f
30ee7fc
cd347e9
d2eaee3
2ddeb06
71b31f0
7e71340
04c97c3
401ca7a
15b953b
f050922
8a27667
e50bc55
7e3f2ef
e5401e7
3e21e48
3c39506
6654742
be9f331
90e84b3
6f721f0
0c0d3ba
4a217bc
f5ebbe8
763a1f8
7b51564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…ct/array_except
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2861,143 +2861,6 @@ case class ArrayRepeat(left: Expression, right: Expression) | |
| } | ||
| } | ||
|
|
||
/**
 * Returns an array of the elements in the union of x and y, without duplicates.
 *
 * Both inputs must be arrays with the same element type. When neither input type may contain
 * nulls, elements are de-duplicated through an `OpenHashSet` (primitive-specialized for int
 * and long elements); otherwise the work is delegated to `CollectionOperations.arrayUnion`.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
      without duplicates.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
       array(1, 2, 3, 5)
  """,
  since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

  /** Succeeds only when both children are arrays and their element types are equal. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val r = super.checkInputDataTypes()
    if ((r == TypeCheckResult.TypeCheckSuccess) &&
      (left.dataType.asInstanceOf[ArrayType].elementType !=
        right.dataType.asInstanceOf[ArrayType].elementType)) {
      TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same")
    } else {
      r
    }
  }

  // Element type is taken from the left child; checkInputDataTypes guarantees that the right
  // child has the same one.
  private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
  private def cnLeft: Boolean = left.dataType.asInstanceOf[ArrayType].containsNull
  private def cnRight: Boolean = right.dataType.asInstanceOf[ArrayType].containsNull

  // The union carries elements of both inputs, so the result may hold a null as soon as either
  // side may. Returning `left.dataType` unchanged would hide nulls coming from `right`.
  override def dataType: DataType = ArrayType(elementType, cnLeft || cnRight)

  /**
   * Interpreted evaluation.
   *
   * @param linput left array, already known to be non-null
   * @param rinput right array, already known to be non-null
   * @return an `ArrayData` holding the distinct elements of both inputs
   */
  override def nullSafeEval(linput: Any, rinput: Any): Any = {
    val larray = linput.asInstanceOf[ArrayData]
    val rarray = rinput.asInstanceOf[ArrayData]

    if (!cnLeft && !cnRight) {
      elementType match {
        case IntegerType =>
          // avoid boxing of primitive int array elements
          val hs = new OpenHashSet[Int]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.getInt(i))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.getInt(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case LongType =>
          // avoid boxing of primitive long array elements
          val hs = new OpenHashSet[Long]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.getLong(i))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.getLong(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case _ =>
          val hs = new OpenHashSet[Any]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.get(i, elementType))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.get(i, elementType))
            i += 1
          }
          new GenericArrayData(hs.iterator.toArray)
      }
    } else {
      CollectionOperations.arrayUnion(larray, rarray, elementType)
    }
  }

  /**
   * Generated-code evaluation; mirrors the branches of [[nullSafeEval]].
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val hs = ctx.freshName("hs")
    val i = ctx.freshName("i")
    val collectionOperations = "org.apache.spark.sql.catalyst.expressions.CollectionOperations"
    val genericArrayData = classOf[GenericArrayData].getName
    val unsafeArrayData = classOf[UnsafeArrayData].getName
    val openHashSet = classOf[OpenHashSet[_]].getName
    // Hand the real element type to the generated code: the interpreted path reads elements
    // with `elementType`, not with an ObjectType placeholder. Lazy so that the primitive
    // branches, which never need it, do not register a reference.
    lazy val et = ctx.addReferenceObj("elementType", elementType, classOf[DataType].getName)
    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cnLeft && !cnRight) {
      elementType match {
        // Only int and long get a primitive-specialized set, exactly as in nullSafeEval.
        // NOTE(review): byte/short were previously routed to the Int-specialized suffix while
        // the class tag and cast were derived from the byte/short type name; presumably those
        // do not agree - they now use the generic branch like the interpreted path.
        case IntegerType =>
          val ptName = CodeGenerator.primitiveTypeName(elementType)
          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
        case LongType =>
          val ptName = CodeGenerator.primitiveTypeName(elementType)
          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", "long")
        case _ =>
          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)",
            s"new $genericArrayData", "Object")
      }
    } else {
      // An empty class tag marks the "inputs may contain null" case handled below.
      ("", "", "", "", "")
    }

    nullSafeCodeGen(ctx, ev, (larray, rarray) => {
      if (classTag != "") {
        s"""
           |$openHashSet $hs = new $openHashSet$postFix($classTag);
           |for (int $i = 0; $i < $larray.numElements(); $i++) {
           |  $hs.add$postFix($larray.$getter);
           |}
           |for (int $i = 0; $i < $rarray.numElements(); $i++) {
           |  $hs.add$postFix($rarray.$getter);
           |}
           |${ev.value} = $arrayBuilder(
           |  ($castType[]) $hs.iterator().toArray($classTag));
         """.stripMargin
      } else {
        s"${ev.value} = ${collectionOperations}$$.MODULE$$.arrayUnion($larray, $rarray, $et);"
      }
    })
  }

  override def prettyName: String = "array_union"
}
|
|
||
| /** | ||
| * Removes all elements that are equal to the given element from the given array | ||
| */ | ||
|
|
@@ -3399,9 +3262,159 @@ case class ArrayDistinct(child: Expression) | |
| override def prettyName: String = "array_distinct" | ||
| } | ||
|
|
||
| object CollectionOperations { | ||
| def arrayUnion(larray: ArrayData, rarray: ArrayData, et: DataType): ArrayData = { | ||
| new GenericArrayData(larray.toArray[AnyRef](et).union(rarray.toArray[AnyRef](et)) | ||
/**
 * Shared implementation for binary expressions that combine two arrays with the same element
 * type. Subclasses supply the two array children through `array1` / `array2` and identify the
 * operation through `typeId`; the evaluation implemented here computes the union without
 * duplicates.
 */
abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes {
  val kindUnion = 1

  /** Identifier of the set operation implemented by the subclass (e.g. `kindUnion`). */
  def typeId: Int

  /** First array operand. */
  def array1: Expression

  /** Second array operand. */
  def array2: Expression

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

  /** Succeeds only when both operands are arrays and their element types are equal. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val r = super.checkInputDataTypes()
    if ((r == TypeCheckResult.TypeCheckSuccess) &&
      (array1.dataType.asInstanceOf[ArrayType].elementType !=
        array2.dataType.asInstanceOf[ArrayType].elementType)) {
      TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same")
    } else {
      r
    }
  }

  // Element type is taken from the first operand; checkInputDataTypes guarantees that the
  // second operand has the same one.
  private def elementType: DataType = array1.dataType.asInstanceOf[ArrayType].elementType

  private def cn1: Boolean = array1.dataType.asInstanceOf[ArrayType].containsNull

  private def cn2: Boolean = array2.dataType.asInstanceOf[ArrayType].containsNull

  // The union carries elements of both inputs, so the result may hold a null as soon as either
  // side may. Returning `array1.dataType` unchanged would hide nulls coming from `array2`.
  override def dataType: DataType = ArrayType(elementType, cn1 || cn2)

  /**
   * Interpreted evaluation.
   *
   * @param input1 first array, already known to be non-null
   * @param input2 second array, already known to be non-null
   * @return an `ArrayData` holding the distinct elements of both inputs
   */
  override def nullSafeEval(input1: Any, input2: Any): Any = {
    val ary1 = input1.asInstanceOf[ArrayData]
    val ary2 = input2.asInstanceOf[ArrayData]

    if (!cn1 && !cn2) {
      elementType match {
        case IntegerType =>
          // avoid boxing of primitive int array elements
          val hs = new OpenHashSet[Int]
          var i = 0
          while (i < ary1.numElements()) {
            hs.add(ary1.getInt(i))
            i += 1
          }
          i = 0
          while (i < ary2.numElements()) {
            hs.add(ary2.getInt(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case LongType =>
          // avoid boxing of primitive long array elements
          val hs = new OpenHashSet[Long]
          var i = 0
          while (i < ary1.numElements()) {
            hs.add(ary1.getLong(i))
            i += 1
          }
          i = 0
          while (i < ary2.numElements()) {
            hs.add(ary2.getLong(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case _ =>
          val hs = new OpenHashSet[Any]
          var i = 0
          while (i < ary1.numElements()) {
            hs.add(ary1.get(i, elementType))
            i += 1
          }
          i = 0
          while (i < ary2.numElements()) {
            hs.add(ary2.get(i, elementType))
            i += 1
          }
          new GenericArrayData(hs.iterator.toArray)
      }
    } else {
      ArraySetUtils.arrayUnion(ary1, ary2, elementType)
    }
  }

  /**
   * Generated-code evaluation; mirrors the branches of [[nullSafeEval]].
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val hs = ctx.freshName("hs")
    val i = ctx.freshName("i")
    // Lower-case local name so the companion object `ArraySetUtils` is not shadowed here.
    val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils"
    val genericArrayData = classOf[GenericArrayData].getName
    val unsafeArrayData = classOf[UnsafeArrayData].getName
    val openHashSet = classOf[OpenHashSet[_]].getName
    // Hand the real element type to the generated code: the interpreted path reads elements
    // with `elementType`, not with an ObjectType placeholder. Lazy so that the primitive
    // branches, which never need it, do not register a reference.
    lazy val et = ctx.addReferenceObj("elementType", elementType, classOf[DataType].getName)
    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) {
      elementType match {
        // Only int and long get a primitive-specialized set, exactly as in nullSafeEval.
        // NOTE(review): byte/short were previously routed to the Int-specialized suffix while
        // the class tag and cast were derived from the byte/short type name; presumably those
        // do not agree - they now use the generic branch like the interpreted path.
        case IntegerType =>
          val ptName = CodeGenerator.primitiveTypeName(elementType)
          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
        case LongType =>
          val ptName = CodeGenerator.primitiveTypeName(elementType)
          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", "long")
        case _ =>
          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)",
            s"new $genericArrayData", "Object")
      }
    } else {
      // An empty class tag marks the "inputs may contain null" case handled below.
      ("", "", "", "", "")
    }

    nullSafeCodeGen(ctx, ev, (ary1, ary2) => {
      if (classTag != "") {
        s"""
           |$openHashSet $hs = new $openHashSet$postFix($classTag);
           |for (int $i = 0; $i < $ary1.numElements(); $i++) {
           |  $hs.add$postFix($ary1.$getter);
           |}
           |for (int $i = 0; $i < $ary2.numElements(); $i++) {
           |  $hs.add$postFix($ary2.$getter);
           |}
           |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag));
         """.stripMargin
      } else {
        s"${ev.value} = ${arraySetUtils}$$.MODULE$$.arrayUnion($ary1, $ary2, $et);"
      }
    })
  }
}
|
|
||
object ArraySetUtils {
  /**
   * Concatenates the elements of both arrays (read as `et`) and drops repeated values, keeping
   * the first occurrence of each.
   *
   * @param array1 first input array
   * @param array2 second input array
   * @param et element type used to read the elements of both arrays
   * @return a `GenericArrayData` with the distinct elements of `array1` followed by those of
   *         `array2`
   */
  def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = {
    val firstElems = array1.toArray[AnyRef](et)
    val secondElems = array2.toArray[AnyRef](et)
    val merged = (firstElems ++ secondElems).distinct
    new GenericArrayData(merged.asInstanceOf[Array[Any]])
  }
}
|
|
||
/**
 * Set union of two arrays: yields every element that occurs in either input, with duplicates
 * removed. Evaluation is inherited from [[ArraySetUtils]]; this class only wires the two
 * children in as the operands and tags the operation as a union.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
      without duplicates. The order of elements in the result is not determined.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
       array(1, 2, 3, 5)
  """,
  since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils {

  override def prettyName: String = "array_union"

  // Operands of the set operation are simply the two children of this binary expression.
  override def array1: Expression = left
  override def array2: Expression = right

  // Marks this expression as the union flavour of the shared base class.
  override def typeId: Int = kindUnion
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can also support `array_union` and `array_except` by changing this 2nd loop, along with a few other small changes. This is why we introduced `ArraySetUtils` in this PR. Other PRs will update `ArraySetUtils` appropriately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is an instance of the usage of `ArraySetUtils`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the final version of `ArraySetUtils` that supports three functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether this abstraction is a good one. The final version seems complex because of a bunch of if-else branches.
I'd rather introduce abstract methods for the parts that differ and override them in the subclasses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your good suggestion.
I will create a new abstract method for this part, which will be overridden by each of the three subclasses.