Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dc9d6f0
initial commit
kiszk Apr 13, 2018
3019840
update description
kiszk Apr 13, 2018
8cee6cf
fix test failure
kiszk Apr 13, 2018
2041ec4
address review comments
kiszk Apr 17, 2018
8c2280b
introduce ArraySetUtils to reuse code among array_union/array_interse…
kiszk Apr 17, 2018
b3a3132
fix python test failure
kiszk Apr 18, 2018
a2c7dd1
fix python test failure
kiszk Apr 18, 2018
5313680
simplification
kiszk Apr 18, 2018
98f8d1f
fix pyspark test failure
kiszk Apr 19, 2018
30ee7fc
address review comments
kiszk Apr 20, 2018
cd347e9
add new tests based on review comment
kiszk Apr 20, 2018
d2eaee3
fix mistakes in rebase
kiszk Apr 20, 2018
2ddeb06
fix unexpected changes
kiszk Apr 20, 2018
71b31f0
merge changes in #21103
kiszk Apr 20, 2018
7e71340
use GenericArrayData if UnsafeArrayData cannot be used
kiszk May 4, 2018
04c97c3
use BinaryArrayExpressionWithImplicitCast
kiszk May 4, 2018
401ca7a
update test cases
kiszk May 4, 2018
15b953b
rebase with master
kiszk May 17, 2018
f050922
support complex types
kiszk May 18, 2018
8a27667
add test cases with duplication in an array
kiszk May 19, 2018
e50bc55
rebase with master
kiszk Jun 1, 2018
7e3f2ef
address review comments
kiszk Jun 1, 2018
e5401e7
address review comment
kiszk Jun 1, 2018
3e21e48
keep the order of input array elements
kiszk Jun 10, 2018
3c39506
address review comments
kiszk Jun 20, 2018
6654742
fix scala style error
kiszk Jun 20, 2018
be9f331
address review comment
kiszk Jun 20, 2018
90e84b3
address review comments
kiszk Jun 22, 2018
6f721f0
address review comments
kiszk Jun 22, 2018
0c0d3ba
address review comments
kiszk Jul 8, 2018
4a217bc
cleanup
kiszk Jul 8, 2018
f5ebbe8
eliminate duplicated code
kiszk Jul 8, 2018
763a1f8
address review comments
kiszk Jul 9, 2018
7b51564
address review comment
kiszk Jul 11, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
introduce ArraySetUtils to reuse code among array_union/array_interse…
…ct/array_except
  • Loading branch information
kiszk committed Jun 27, 2018
commit 8c2280be254769b51342c6afa41801e88c6b0bee
Original file line number Diff line number Diff line change
Expand Up @@ -2861,143 +2861,6 @@ case class ArrayRepeat(left: Expression, right: Expression)
}
}

/**
* Returns an array of the elements in the union of x and y, without duplicates
*/
@ExpressionDescription(
usage = """
_FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
without duplicates.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
array(1, 2, 3, 5)
""",
since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  // Both operands must be arrays; checkInputDataTypes additionally compares their element types.
  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

  /**
   * Runs the inherited input type check and, when that succeeds, also fails if the two
   * arrays declare different element types.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val r = super.checkInputDataTypes()
    if ((r == TypeCheckResult.TypeCheckSuccess) &&
      (left.dataType.asInstanceOf[ArrayType].elementType !=
        right.dataType.asInstanceOf[ArrayType].elementType)) {
      TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same")
    } else {
      r
    }
  }

  // The result uses the data type of the left operand.
  override def dataType: DataType = left.dataType

  private def elementType = dataType.asInstanceOf[ArrayType].elementType
  // containsNull flags of the left and right operand array types.
  private def cnLeft = left.dataType.asInstanceOf[ArrayType].containsNull
  private def cnRight = right.dataType.asInstanceOf[ArrayType].containsNull

  /**
   * Interpreted evaluation of the union.
   *
   * When neither array type is declared as containing nulls, the elements of both arrays are
   * added to an OpenHashSet (with dedicated Int and Long paths that read primitives directly)
   * and the set contents become the result. Otherwise the work is delegated to
   * CollectionOperations.arrayUnion.
   *
   * @param linput the evaluated left operand, cast to ArrayData
   * @param rinput the evaluated right operand, cast to ArrayData
   * @return an array built from the distinct elements of both inputs
   */
  override def nullSafeEval(linput: Any, rinput: Any): Any = {
    val larray = linput.asInstanceOf[ArrayData]
    val rarray = rinput.asInstanceOf[ArrayData]

    if (!cnLeft && !cnRight) {
      elementType match {
        case IntegerType =>
          // avoid boxing of primitive int array elements
          val hs = new OpenHashSet[Int]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.getInt(i))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.getInt(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case LongType =>
          // avoid boxing of primitive long array elements
          val hs = new OpenHashSet[Long]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.getLong(i))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.getLong(i))
            i += 1
          }
          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
        case _ =>
          // Generic path: elements are read as objects and deduplicated by the hash set.
          // NOTE(review): whether this deduplicates element types such as BinaryType
          // correctly depends on their equals/hashCode - confirm.
          val hs = new OpenHashSet[Any]
          var i = 0
          while (i < larray.numElements()) {
            hs.add(larray.get(i, elementType))
            i += 1
          }
          i = 0
          while (i < rarray.numElements()) {
            hs.add(rarray.get(i, elementType))
            i += 1
          }
          new GenericArrayData(hs.iterator.toArray)
      }
    } else {
      CollectionOperations.arrayUnion(larray, rarray, elementType)
    }
  }

  /**
   * Generates Java code mirroring nullSafeEval: a hash-set based union when neither array type
   * is declared as containing nulls, otherwise a call to CollectionOperations.arrayUnion.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val hs = ctx.freshName("hs")
    val i = ctx.freshName("i")
    val collectionOperations = "org.apache.spark.sql.catalyst.expressions.CollectionOperations"
    val genericArrayData = classOf[GenericArrayData].getName
    val unsafeArrayData = classOf[UnsafeArrayData].getName
    val openHashSet = classOf[OpenHashSet[_]].getName
    // Java expression producing an ObjectType for Object.class, spliced into generated code.
    val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)"
    // Per element type: specialized-class suffix, ClassTag expression, element getter,
    // result array builder and the Java array element type used for the cast.
    // An empty classTag marks the "may contain nulls" case handled by the fallback below.
    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cnLeft && !cnRight) {
      val ptName = CodeGenerator.primitiveTypeName(elementType)
      elementType match {
        // NOTE(review): Byte and Short reuse the Int-specialized set suffix while passing
        // their own ClassTag and getter - confirm this combination is intended.
        case ByteType | ShortType | IntegerType =>
          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
        case LongType =>
          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
            s"$unsafeArrayData.fromPrimitiveArray", "long")
        case _ =>
          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)",
            s"new $genericArrayData", "Object")
      }
    } else {
      ("", "", "", "", "")
    }

    nullSafeCodeGen(ctx, ev, (larray, rarray) => {
      if (classTag != "") {
        s"""
           |$openHashSet $hs = new $openHashSet$postFix($classTag);
           |for (int $i = 0; $i < $larray.numElements(); $i++) {
           | $hs.add$postFix($larray.$getter);
           |}
           |for (int $i = 0; $i < $rarray.numElements(); $i++) {
           | $hs.add$postFix($rarray.$getter);
           |}
           |${ev.value} = $arrayBuilder(
           | ($castType[]) $hs.iterator().toArray($classTag));
         """.stripMargin
      } else {
        // At least one input may contain nulls: defer to the shared helper object.
        s"${ev.value} = $collectionOperations$$.MODULE$$.arrayUnion($larray, $rarray, $ot);"
      }
    })
  }

  override def prettyName: String = "array_union"
}

/**
* Remove all elements that equal to element from the given array
*/
Expand Down Expand Up @@ -3399,9 +3262,159 @@ case class ArrayDistinct(child: Expression)
override def prettyName: String = "array_distinct"
}

object CollectionOperations {
def arrayUnion(larray: ArrayData, rarray: ArrayData, et: DataType): ArrayData = {
new GenericArrayData(larray.toArray[AnyRef](et).union(rarray.toArray[AnyRef](et))
abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes {
// Identifier for the union operation; ArrayUnion returns it from typeId.
val kindUnion = 1

// Kind of set operation implemented by the concrete subclass (e.g. kindUnion).
def typeId: Int

// First array operand; ArrayUnion binds it to its `left` child.
def array1: Expression

// Second array operand; ArrayUnion binds it to its `right` child.
def array2: Expression

// Both operands must be arrays; element-type agreement is checked in checkInputDataTypes.
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

/**
 * Validates the operands: the inherited check must pass, and both arrays must declare the
 * same element type. Any failure reported by the inherited check is returned unchanged.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      val firstElementType = array1.dataType.asInstanceOf[ArrayType].elementType
      val secondElementType = array2.dataType.asInstanceOf[ArrayType].elementType
      if (firstElementType == secondElementType) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same")
      }
    case inheritedFailure =>
      inheritedFailure
  }
}

// The result uses the data type of the first array operand.
override def dataType: DataType = array1.dataType

// Element type of the result array type.
private def elementType = dataType.asInstanceOf[ArrayType].elementType

// containsNull flag of the first operand's array type.
private def cn1 = array1.dataType.asInstanceOf[ArrayType].containsNull

// containsNull flag of the second operand's array type.
private def cn2 = array2.dataType.asInstanceOf[ArrayType].containsNull

override def nullSafeEval(input1: Any, input2: Any): Any = {
val ary1 = input1.asInstanceOf[ArrayData]
val ary2 = input2.asInstanceOf[ArrayData]

if (!cn1 && !cn2) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
val hs = new OpenHashSet[Int]
var i = 0
while (i < ary1.numElements()) {
hs.add(ary1.getInt(i))
i += 1
}
i = 0
while (i < ary2.numElements()) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also support array_union and array_except by changing this 2nd loop with small other changes. This is why we introduced ArraySetUtils in this PR.

Other PRs will update ArraySetUtils appropriately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an instance of the usage of ArraySetUtils.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the final version of ArraySetUtils that supports three functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether this abstraction is good or not. The final version seems complex because of a bunch of if-else branches.
I'd rather introduce abstract methods for the difference and override them in the subclasses.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your good suggestion.
I will create a new abstract method for this part, which will be overridden by each of the three subclasses.

hs.add(ary2.getInt(i))
i += 1
}
UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
case LongType =>
// avoid boxing of primitive long array elements
val hs = new OpenHashSet[Long]
var i = 0
while (i < ary1.numElements()) {
hs.add(ary1.getLong(i))
i += 1
}
i = 0
while (i < ary2.numElements()) {
hs.add(ary2.getLong(i))
i += 1
}
UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
case _ =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case still doesn't work for BinaryType?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Addressing this comment will fix this issue.

val hs = new OpenHashSet[Any]
var i = 0
while (i < ary1.numElements()) {
hs.add(ary1.get(i, elementType))
i += 1
}
i = 0
while (i < ary2.numElements()) {
hs.add(ary2.get(i, elementType))
i += 1
}
new GenericArrayData(hs.iterator.toArray)
}
} else {
ArraySetUtils.arrayUnion(ary1, ary2, elementType)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val hs = ctx.freshName("hs")
val i = ctx.freshName("i")
val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils"
val genericArrayData = classOf[GenericArrayData].getName
val unsafeArrayData = classOf[UnsafeArrayData].getName
val openHashSet = classOf[OpenHashSet[_]].getName
val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)"
val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) {
val ptName = CodeGenerator.primitiveTypeName(elementType)
elementType match {
case ByteType | ShortType | IntegerType =>
(s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
case LongType =>
(s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
s"$unsafeArrayData.fromPrimitiveArray", "long")
case _ =>
("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)",
s"new $genericArrayData", "Object")
}
} else {
("", "", "", "", "")
}

nullSafeCodeGen(ctx, ev, (ary1, ary2) => {
if (classTag != "") {
s"""
|$openHashSet $hs = new $openHashSet$postFix($classTag);
|for (int $i = 0; $i < $ary1.numElements(); $i++) {
| $hs.add$postFix($ary1.$getter);
|}
|for (int $i = 0; $i < $ary2.numElements(); $i++) {
| $hs.add$postFix($ary2.$getter);
|}
|${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we shouldn't use iterator() to avoid box/unbox. Iterator is not specialized.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, great catch. I confirmed that there is no specialized iterator() in `OpenHashSet$mcI$sp`. I will rewrite this.

""".stripMargin
} else {
val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)"
s"${ev.value} = $ArraySetUtils$$.MODULE$$.arrayUnion($ary1, $ary2, $ot);"
}
})
}
}

object ArraySetUtils {
  /**
   * Builds the union of two arrays by materializing both as object arrays, concatenating
   * them (first array's elements first) and dropping duplicate elements.
   *
   * @param array1 the first input array
   * @param array2 the second input array
   * @param et the data type used to read elements out of both arrays
   * @return a GenericArrayData holding the distinct elements of both inputs
   */
  def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = {
    val firstElements = array1.toArray[AnyRef](et)
    val secondElements = array2.toArray[AnyRef](et)
    val distinctElements = (firstElements ++ secondElements).distinct
    new GenericArrayData(distinctElements.asInstanceOf[Array[Any]])
  }
}

/**
* Returns an array of the elements in the union of the `left` and `right` arrays, without
* duplicates. Input checking, evaluation and code generation are inherited from ArraySetUtils;
* this class only binds the operands and names the function.
*/
@ExpressionDescription(
usage = """
_FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
without duplicates. The order of elements in the result is not determined.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
array(1, 2, 3, 5)
""",
since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils {
// Identifies this expression as the union operation.
override def typeId: Int = kindUnion
// Map the BinaryExpression children onto the operand names used by ArraySetUtils.
override def array1: Expression = left
override def array2: Expression = right

// Function name of this expression.
override def prettyName: String = "array_union"
}
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
}

test("array union functions") {
test("array_union functions") {
val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
val ans1 = Row(Seq(4, 1, 3, 2))
checkAnswer(df1.select(array_union($"a", $"b")), ans1)
Expand Down