[Closed] Changes from 1 commit
Commits (34)
dc9d6f0  initial commit (kiszk, Apr 13, 2018)
3019840  update description (kiszk, Apr 13, 2018)
8cee6cf  fix test failure (kiszk, Apr 13, 2018)
2041ec4  address review comments (kiszk, Apr 17, 2018)
8c2280b  introduce ArraySetUtils to reuse code among array_union/array_interse… (kiszk, Apr 17, 2018)
b3a3132  fix python test failure (kiszk, Apr 18, 2018)
a2c7dd1  fix python test failure (kiszk, Apr 18, 2018)
5313680  simplification (kiszk, Apr 18, 2018)
98f8d1f  fix pyspark test failure (kiszk, Apr 19, 2018)
30ee7fc  address review comments (kiszk, Apr 20, 2018)
cd347e9  add new tests based on review comment (kiszk, Apr 20, 2018)
d2eaee3  fix mistakes in rebase (kiszk, Apr 20, 2018)
2ddeb06  fix unexpected changes (kiszk, Apr 20, 2018)
71b31f0  merge changes in #21103 (kiszk, Apr 20, 2018)
7e71340  use GenericArrayData if UnsafeArrayData cannot be used (kiszk, May 4, 2018)
04c97c3  use BinaryArrayExpressionWithImplicitCast (kiszk, May 4, 2018)
401ca7a  update test cases (kiszk, May 4, 2018)
15b953b  rebase with master (kiszk, May 17, 2018)
f050922  support complex types (kiszk, May 18, 2018)
8a27667  add test cases with duplication in an array (kiszk, May 19, 2018)
e50bc55  rebase with master (kiszk, Jun 1, 2018)
7e3f2ef  address review comments (kiszk, Jun 1, 2018)
e5401e7  address review comment (kiszk, Jun 1, 2018)
3e21e48  keep the order of input array elements (kiszk, Jun 10, 2018)
3c39506  address review comments (kiszk, Jun 20, 2018)
6654742  fix scala style error (kiszk, Jun 20, 2018)
be9f331  address review comment (kiszk, Jun 20, 2018)
90e84b3  address review comments (kiszk, Jun 22, 2018)
6f721f0  address review comments (kiszk, Jun 22, 2018)
0c0d3ba  address review comments (kiszk, Jul 8, 2018)
4a217bc  cleanup (kiszk, Jul 8, 2018)
f5ebbe8  eliminate duplicated code (kiszk, Jul 8, 2018)
763a1f8  address review comments (kiszk, Jul 9, 2018)
7b51564  address review comment (kiszk, Jul 11, 2018)
address review comments
kiszk committed Jun 27, 2018
commit 2041ec45efdcb2b3ae9dfc7c5b7c6dc26c0091ea
7 changes: 4 additions & 3 deletions python/pyspark/sql/functions.py
@@ -1942,6 +1942,7 @@ def concat(*cols):
     return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
 
 
+@ignore_unicode_prefix
 @since(2.4)
 def array_position(col, value):
     """
@@ -2017,16 +2018,16 @@ def array_distinct(col):
 @since(2.4)
 def array_union(col1, col2):
     """
-    Collection function: Returns an array of the elements in the union of col1 and col2,
-    without duplicates
+    Collection function: returns an array of the elements in the union of col1 and col2,
@viirya (Member) commented on Jul 9, 2018:

If the array of col1 itself contains duplicate elements, what does it do? De-duplicate them too?

E.g.,

df = spark.createDataFrame([Row(c1=["b", "a", "c", "c"], c2=["c", "d", "a", "f"])])
df.select(array_union(df.c1, df.c2)).collect()

Member:

After reading the code, it seems it de-duplicates all elements from the two arrays. Is this behavior the same as Presto?

Member Author (@kiszk):

I will add tests for duplication.
Yes, this will de-duplicate. I think it is the same behavior as Presto.

+    without duplicates. The order of elements in the result is not determined.
 
     :param col1: name of column containing array
     :param col2: name of column containing array
 
     >>> from pyspark.sql import Row
     >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
     >>> df.select(array_union(df.c1, df.c2)).collect()
-    [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
+    [Row(array_union(c1, c2)=[u'b', u'c', u'd', u'a', u'f'])]
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))
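The de-duplication question in the thread above is easy to check directly. Here is a minimal sketch (not part of this PR) using the Scala API, assuming a local SparkSession: duplicates inside a single input array are removed too, and element order is not guaranteed.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_union

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// "c" appears twice in c1; the result still contains it only once.
val df = Seq((Seq("b", "a", "c", "c"), Seq("c", "d", "a", "f"))).toDF("c1", "c2")
df.select(array_union($"c1", $"c2")).show(truncate = false)
// e.g. [b, a, c, d, f]; the exact order is not determined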
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2876,7 +2876,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
   """,
   since = "2.4.0")
 case class ArrayUnion(left: Expression, right: Expression)
-  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+  extends BinaryExpression with ExpectsInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
 
@@ -2893,46 +2893,106 @@ case class ArrayUnion(left: Expression, right: Expression)
 
   override def dataType: DataType = left.dataType
 
+  private def elementType = dataType.asInstanceOf[ArrayType].elementType
+  private def cnLeft = left.dataType.asInstanceOf[ArrayType].containsNull
+  private def cnRight = right.dataType.asInstanceOf[ArrayType].containsNull
+
   override def nullSafeEval(linput: Any, rinput: Any): Any = {
-    val elementType = dataType.asInstanceOf[ArrayType].elementType
-    val cnl = left.dataType.asInstanceOf[ArrayType].containsNull
-    val cnr = right.dataType.asInstanceOf[ArrayType].containsNull
     val larray = linput.asInstanceOf[ArrayData]
     val rarray = rinput.asInstanceOf[ArrayData]
 
-    if (!cnl && !cnr && elementType == IntegerType) {
-      // avoid boxing primitive int array elements
-      val hs = new OpenHashSet[Int]
-      var i = 0
-      while (i < larray.numElements()) {
-        hs.add(larray.getInt(i))
-        i += 1
-      }
-      i = 0
-      while (i < rarray.numElements()) {
-        hs.add(rarray.getInt(i))
-        i += 1
-      }
-      UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
-    } else if (!cnl && !cnr && elementType == LongType) {
-      // avoid boxing of primitive long array elements
-      val hs = new OpenHashSet[Long]
-      var i = 0
-      while (i < larray.numElements()) {
-        hs.add(larray.getLong(i))
-        i += 1
-      }
-      i = 0
-      while (i < rarray.numElements()) {
-        hs.add(rarray.getLong(i))
-        i += 1
-      }
-      UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
-    } else {
-      new GenericArrayData(
-        (larray.toArray[AnyRef](elementType) union rarray.toArray[AnyRef](elementType))
-          .distinct.asInstanceOf[Array[Any]])
+    if (!cnLeft && !cnRight) {
+      elementType match {
+        case IntegerType =>
+          // avoid boxing of primitive int array elements
+          val hs = new OpenHashSet[Int]
+          var i = 0
+          while (i < larray.numElements()) {
+            hs.add(larray.getInt(i))
+            i += 1
+          }
+          i = 0
+          while (i < rarray.numElements()) {
+            hs.add(rarray.getInt(i))
+            i += 1
+          }
+          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+        case LongType =>
+          // avoid boxing of primitive long array elements
+          val hs = new OpenHashSet[Long]
+          var i = 0
+          while (i < larray.numElements()) {
+            hs.add(larray.getLong(i))
+            i += 1
+          }
+          i = 0
+          while (i < rarray.numElements()) {
+            hs.add(rarray.getLong(i))
+            i += 1
+          }
+          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+        case _ =>
+          val hs = new OpenHashSet[Any]
+          var i = 0
+          while (i < larray.numElements()) {
+            hs.add(larray.get(i, elementType))
+            i += 1
+          }
+          i = 0
+          while (i < rarray.numElements()) {
+            hs.add(rarray.get(i, elementType))
+            i += 1
+          }
+          new GenericArrayData(hs.iterator.toArray)
+      }
+    } else {
+      CollectionOperations.arrayUnion(larray, rarray, elementType)
     }
   }
 
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val hs = ctx.freshName("hs")
+    val i = ctx.freshName("i")
+    val collectionOperations = "org.apache.spark.sql.catalyst.expressions.CollectionOperations"
+    val genericArrayData = classOf[GenericArrayData].getName
+    val unsafeArrayData = classOf[UnsafeArrayData].getName
+    val openHashSet = classOf[OpenHashSet[_]].getName
+    val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)"
+    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cnLeft && !cnRight) {
+      val ptName = CodeGenerator.primitiveTypeName(elementType)
+      elementType match {
+        case ByteType | ShortType | IntegerType =>
+          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
+            s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
+        case LongType =>
+          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
+            s"$unsafeArrayData.fromPrimitiveArray", "long")
+        case _ =>
+          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)",
+            s"new $genericArrayData", "Object")
+      }
+    } else {
+      ("", "", "", "", "")
+    }
+
+    nullSafeCodeGen(ctx, ev, (larray, rarray) => {
+      if (classTag != "") {
+        s"""
+           |$openHashSet $hs = new $openHashSet$postFix($classTag);
+           |for (int $i = 0; $i < $larray.numElements(); $i++) {
+           |  $hs.add$postFix($larray.$getter);
+           |}
+           |for (int $i = 0; $i < $rarray.numElements(); $i++) {
+           |  $hs.add$postFix($rarray.$getter);
+           |}
+           |${ev.value} = $arrayBuilder(
+           |  ($castType[]) $hs.iterator().toArray($classTag));
+         """.stripMargin
+      } else {
+        val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)"
+        s"${ev.value} = $collectionOperations$$.MODULE$$.arrayUnion($larray, $rarray, $ot);"
+      }
+    })
+  }
+
   override def prettyName: String = "array_union"
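A note on the generated code above: Spark's OpenHashSet is specialized, which is why the generated Java appends the $mcI$sp / $mcJ$sp suffixes when calling the int and long variants; this keeps the hot loop free of boxing. The interpreted IntegerType path amounts to the following sketch (a rough equivalent, not part of the PR, assuming spark-core on the classpath; the input values are illustrative):

import org.apache.spark.util.collection.OpenHashSet

// Add every element of both arrays to a primitive-specialized hash set,
// then dump the set back out as an int array.
val hs = new OpenHashSet[Int]
Array(1, 2, 2, 3).foreach(hs.add)
Array(3, 4).foreach(hs.add)
val union: Array[Int] = hs.iterator.toArray
// The result order is the set's iteration order, not the input order,
// which is why the docs say the order is "not determined".
println(union.mkString(", "))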
@@ -3338,3 +3398,10 @@ case class ArrayDistinct(child: Expression)
 
   override def prettyName: String = "array_distinct"
 }
+
+object CollectionOperations {
+  def arrayUnion(larray: ArrayData, rarray: ArrayData, et: DataType): ArrayData = {
+    new GenericArrayData(larray.toArray[AnyRef](et).union(rarray.toArray[AnyRef](et))
+      .distinct.asInstanceOf[Array[Any]])
+  }
+}
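For nullable or non-primitive element types, ArrayUnion delegates to CollectionOperations.arrayUnion above, which is concatenate-then-distinct. A dependency-free sketch of the same semantics (ArrayUnionSketch and its method are hypothetical names; plain Scala collections stand in for ArrayData):

object ArrayUnionSketch {
  // Mirrors CollectionOperations.arrayUnion: concatenate both inputs,
  // then de-duplicate. Seq#distinct keeps the first occurrence of each
  // element, so this path preserves encounter order and keeps nulls.
  def arrayUnion[T](left: Seq[T], right: Seq[T]): Seq[T] =
    (left ++ right).distinct

  def main(args: Array[String]): Unit = {
    // Matches one of the checkEvaluation cases in the suite below.
    println(arrayUnion(Seq("b", "a", "c"), Seq("b", null, "a", "g")))
    // List(b, a, c, null, g)
  }
}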
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1185,6 +1185,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
     val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType))
     val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType))
+    val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
+    val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false))
 
     val a30 = Literal.create(Seq(null, null), ArrayType(NullType))
 
@@ -1201,6 +1203,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f"))
     checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g"))
+    checkEvaluation(ArrayUnion(a23, a24), Seq("b", "c", "d", "a", "f"))
 
     checkEvaluation(ArrayUnion(a30, a30), Seq(null))
Member:

What if one of the two arguments is null?

Member Author (@kiszk):

Good question. I cannot see such a test case in Presto.
Let me think.

Member Author (@kiszk):

Umm, when only one of the arguments is null, an unexpected TreeNodeException occurs.

checkEvaluation(ArrayUnion(a20, a30), Seq("b", "a", "c", null))
After applying rule org.apache.spark.sql.catalyst.optimizer.EliminateDistinct in batch Eliminate Distinct, the structural integrity of the plan is broken., tree:
'Project [array_union([b,a,c], [null,null]) AS Optimized(array_union([b,a,c], [null,null]))#71]
+- OneRowRelation

org.apache.spark.sql.catalyst.errors.package$TreeNodeException: After applying rule org.apache.spark.sql.catalyst.optimizer.EliminateDistinct in batch Eliminate Distinct, the structural integrity of the plan is broken., tree:
'Project [array_union([b,a,c], [null,null]) AS Optimized(array_union([b,a,c], [null,null]))#71]
+- OneRowRelation


	at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:106)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:84)
	at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
	at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
	at scala.collection.mutable.WrappedArray.foldLeft(WrappedArray.scala:35)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:84)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:76)
	at scala.collection.immutable.List.foreach(List.scala:381)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:76)
	at org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper$class.checkEvaluationWithOptimization(ExpressionEvalHelper.scala:252)
...

Member:

I'm not sure of the reason (maybe because the element types are different?), but I meant something like:

checkEvaluation(ArrayUnion(a20, Literal.create(null, ArrayType(StringType))), ...?)

Member Author (@kiszk):

Ah, I see. Thanks. Your example returns null. Since the following test throws an exception, I think it makes sense that your example returns null. WDYT?

    val df8 = Seq((null, Array("a"))).toDF("a", "b")
    intercept[AnalysisException] {
      df8.select(array_union($"a", $"b"))
    }

Member:

Returning null sounds good, but what do you mean by "Since the following test throws an exception"? What exception is the test throwing?

Member Author (@kiszk):

The following error occurs. Looking at other tests, this does not seem unusual; it happens because null carries no type information.

cannot resolve 'array_union(NULL, `b`)' due to data type mismatch: Element type in both arrays must be the same;;
'Project [array_union(null, b#118) AS array_union(a, b)#121]
+- AnalysisBarrier
      +- Project [_1#114 AS a#117, _2#115 AS b#118]
         +- LocalRelation [_1#114, _2#115]

org.apache.spark.sql.AnalysisException: cannot resolve 'array_union(NULL, `b`)' due to data type mismatch: Element type in both arrays must be the same;;
'Project [array_union(null, b#118) AS array_union(a, b)#121]
+- AnalysisBarrier
      +- Project [_1#114 AS a#117, _2#115 AS b#118]
         +- LocalRelation [_1#114, _2#115]

	at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:93)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:85)
...

Member:

Ah, I see. Maybe the purpose of the test is not what I thought.
Seems like what I wanted is included in the latest updates.

}
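To summarize the thread above: a typed null array argument makes array_union return null (nullSafeEval is skipped for null inputs), while an untyped NULL literal fails analysis because it carries no element type. A sketch of both behaviors, assuming the SparkSession and implicits from the earlier sketch (dfNull is a hypothetical name):

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.{array_union, lit}

// Typed null array: the expression is null-safe and yields null.
val dfNull = Seq((Seq("a", "b"), Option.empty[Seq[String]])).toDF("a", "b")
dfNull.select(array_union($"a", $"b")).show()  // single row, value null

// Untyped NULL literal: no element type information, so analysis fails
// with "cannot resolve 'array_union(NULL, `b`)' due to data type mismatch".
try {
  dfNull.select(array_union(lit(null), $"b")).collect()
} catch {
  case e: AnalysisException => println(e.getMessage)
}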
sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3199,6 +3199,7 @@ object functions {
 
   /**
    * Returns an array of the elements in the union of the given two arrays, without duplicates.
+   * The order of elements in the result is not determined.
    *
    * @group collection_funcs
    * @since 2.4.0