Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address review comments
  • Loading branch information
kiszk committed Aug 6, 2018
commit ce755e2b049ca000d6da754654b792e181e6d904
Original file line number Diff line number Diff line change
Expand Up @@ -4043,7 +4043,7 @@ object ArrayUnion {
array2, without duplicates.
""",
examples = """
Examples:Fun
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
array(1, 3)
""",
Expand All @@ -4060,81 +4060,89 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
@transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = {
if (elementTypeSupportEquals) {
(array1, array2) =>
val hs = new OpenHashSet[Any]
val hsResult = new OpenHashSet[Any]
var foundNullElement = false
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
foundNullElement = true
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (foundNullElement) {
arrayBuffer += null
foundNullElement = false
if (array1.numElements() != 0 && array2.numElements() != 0) {
val hs = new OpenHashSet[Any]
val hsResult = new OpenHashSet[Any]
var foundNullElement = false
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
foundNullElement = true
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
}
} else {
val elem = array1.get(i, elementType)
if (hs.contains(elem) && !hsResult.contains(elem)) {
arrayBuffer += elem
hsResult.add(elem)
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (foundNullElement) {
arrayBuffer += null
foundNullElement = false
}
} else {
val elem = array1.get(i, elementType)
if (hs.contains(elem) && !hsResult.contains(elem)) {
arrayBuffer += elem
hsResult.add(elem)
}
}
i += 1
}
i += 1
new GenericArrayData(arrayBuffer)
} else {
new GenericArrayData(Seq.empty)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we use `Array.empty` or `Array.emptyObjectArray` here instead of `Seq.empty`?

}
new GenericArrayData(arrayBuffer)
} else {
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadySeenNull = false
var i = 0
while (i < array1.numElements()) {
var found = false
val elem1 = array1.get(i, elementType)
if (array1.isNullAt(i)) {
if (!alreadySeenNull) {
if (array1.numElements() != 0 && array2.numElements() != 0) {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadySeenNull = false
var i = 0
while (i < array1.numElements()) {
var found = false
val elem1 = array1.get(i, elementType)
if (array1.isNullAt(i)) {
if (!alreadySeenNull) {
var j = 0
while (!found && j < array2.numElements()) {
found = array2.isNullAt(j)
j += 1
}
// array2 is scanned only once for null element
alreadySeenNull = true
}
} else {
var j = 0
while (!found && j < array2.numElements()) {
found = array2.isNullAt(j)
j += 1
}
// array2 is scanned only once for null element
alreadySeenNull = true
}
} else {
var j = 0
while (!found && j < array2.numElements()) {
if (!array2.isNullAt(j)) {
val elem2 = array2.get(j, elementType)
if (ordering.equiv(elem1, elem2)) {
// check whether elem1 is already stored in arrayBuffer
var foundArrayBuffer = false
var k = 0
while (!foundArrayBuffer && k < arrayBuffer.size) {
val va = arrayBuffer(k)
foundArrayBuffer = (va != null) && ordering.equiv(va, elem1)
k += 1
if (!array2.isNullAt(j)) {
val elem2 = array2.get(j, elementType)
if (ordering.equiv(elem1, elem2)) {
// check whether elem1 is already stored in arrayBuffer
var foundArrayBuffer = false
var k = 0
while (!foundArrayBuffer && k < arrayBuffer.size) {
val va = arrayBuffer(k)
foundArrayBuffer = (va != null) && ordering.equiv(va, elem1)
k += 1
}
found = !foundArrayBuffer
}
found = !foundArrayBuffer
}
j += 1
}
j += 1
}
if (found) {
arrayBuffer += elem1
}
i += 1
}
if (found) {
arrayBuffer += elem1
}
i += 1
new GenericArrayData(arrayBuffer)
} else {
new GenericArrayData(Seq.empty)
}
new GenericArrayData(arrayBuffer)
}
}

Expand Down Expand Up @@ -4162,9 +4170,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val hashSetResult = ctx.freshName("hashSetResult")
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"

def withArray2NullCheck(body: String): String =
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
Expand Down Expand Up @@ -4250,8 +4257,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
|for (int $i = 0; $i < $array2.numElements(); $i++) {
| $writeArray2ToHashSet
|}
|$arrayBuilderClass $builder =
| ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
|$arrayBuilderClass $builder = new $arrayBuilderClass();
|int $size = 0;
|for (int $i = 0; $i < $array1.numElements(); $i++) {
| $processArray1
Expand Down Expand Up @@ -4396,9 +4402,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"

def withArray2NullCheck(body: String): String =
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
Expand Down Expand Up @@ -4474,8 +4479,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
|for (int $i = 0; $i < $array2.numElements(); $i++) {
| $writeArray2ToHashSet
|}
|$arrayBuilderClass $builder =
| ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
|$arrayBuilderClass $builder = new $arrayBuilderClass();
|int $size = 0;
|for (int $i = 0; $i < $array1.numElements(); $i++) {
| $processArray1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1679,26 +1679,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
val df6 = Seq((null, null)).toDF("a", "b")
intercept[AnalysisException] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

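Could you also check the error message? For example, assert that the message of the intercepted `AnalysisException` contains "data type mismatch".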

df6.select(array_intersect($"a", $"b"))
}
}.getMessage.contains("data type mismatch")
intercept[AnalysisException] {
df6.selectExpr("array_intersect(a, b)")
}
}.getMessage.contains("data type mismatch")

val df7 = Seq((Array(1), Array("a"))).toDF("a", "b")
intercept[AnalysisException] {
df7.select(array_intersect($"a", $"b"))
}
}.getMessage.contains("data type mismatch")
intercept[AnalysisException] {
df7.selectExpr("array_intersect(a, b)")
}
}.getMessage.contains("data type mismatch")

val df8 = Seq((null, Array("a"))).toDF("a", "b")
intercept[AnalysisException] {
df8.select(array_intersect($"a", $"b"))
}
}.getMessage.contains("data type mismatch")
intercept[AnalysisException] {
df8.selectExpr("array_intersect(a, b)")
}
}.getMessage.contains("data type mismatch")
}

test("transform function - array for primitive type not containing null") {
Expand Down