Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address review comments
  • Loading branch information
kiszk committed Jul 27, 2018
commit 02b809bf11fe225ae5de16cbe8e2cda37d0d56b4
Original file line number Diff line number Diff line change
Expand Up @@ -4055,237 +4055,137 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
pos
}

override def nullSafeEval(input1: Any, input2: Any): Any = {
val array1 = input1.asInstanceOf[ArrayData]
val array2 = input2.asInstanceOf[ArrayData]

if (elementTypeSupportEquals) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
// calculate result array size
hsInt = new OpenHashSet[Int]
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
// allocate result array
hsInt = new OpenHashSet[Int]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
IntegerType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
}
// assign elements into the result array
evalIntLongPrimitiveType(array1, array2, resultArray, false)
resultArray
case LongType =>
// avoid boxing of primitive long array elements
// calculate result array size
hsLong = new OpenHashSet[Long]
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
// allocate result array
hsLong = new OpenHashSet[Long]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
LongType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
}
// assign elements into the result array
evalIntLongPrimitiveType(array1, array2, resultArray, true)
resultArray
case _ =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
notFoundNullElement = false
val exceptEquals: (ArrayData, ArrayData) => ArrayData = {
(array1: ArrayData, array2: ArrayData) =>
if (elementTypeSupportEquals) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
// calculate result array size
hsInt = new OpenHashSet[Int]
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
// allocate result array
hsInt = new OpenHashSet[Int]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
IntegerType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
UnsafeArrayData.forPrimitiveArray(
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (notFoundNullElement) {
arrayBuffer += null
// assign elements into the result array
evalIntLongPrimitiveType(array1, array2, resultArray, false)
resultArray
case LongType =>
// avoid boxing of primitive long array elements
// calculate result array size
hsLong = new OpenHashSet[Long]
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
// allocate result array
hsLong = new OpenHashSet[Long]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
LongType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
}
// assign elements into the result array
evalIntLongPrimitiveType(array1, array2, resultArray, true)
resultArray
case _ =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
notFoundNullElement = false
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (notFoundNullElement) {
arrayBuffer += null
notFoundNullElement = false
}
} else {
val elem = array1.get(i, elementType)
if (!hs.contains(elem)) {
arrayBuffer += elem
hs.add(elem)
}
}
i += 1
}
new GenericArrayData(arrayBuffer)
}
} else {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var scannedNullElements = false
var i = 0
while (i < array1.numElements()) {
var found = false
val elem1 = array1.get(i, elementType)
if (elem1 == null) {
if (!scannedNullElements) {
var j = 0
while (!found && j < array2.numElements()) {
found = array2.isNullAt(j)
j += 1
}
// array2 is scanned only once for null element
scannedNullElements = true
} else {
val elem = array1.get(i, elementType)
if (!hs.contains(elem)) {
arrayBuffer += elem
hs.add(elem)
found = true
}
} else {
var j = 0
while (!found && j < array2.numElements()) {
val elem2 = array2.get(j, elementType)
if (elem2 != null) {
found = ordering.equiv(elem1, elem2)
}
j += 1
}
if (!found) {
// check whether elem1 is already stored in arrayBuffer
var k = 0
while (!found && k < arrayBuffer.size) {
val va = arrayBuffer(k)
found = (va != null) && ordering.equiv(va, elem1)
k += 1
}
}
i += 1
}
new GenericArrayData(arrayBuffer)
if (!found) {
arrayBuffer += elem1
}
i += 1
}
new GenericArrayData(arrayBuffer)
}
} else {
ArrayExcept.exceptOrdering(array1, array2, elementType, ordering)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val i = ctx.freshName("i")
val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
if (elementTypeSupportEquals) {
elementType match {
case ByteType | ShortType | IntegerType | LongType =>
val ptName = CodeGenerator.primitiveTypeName(elementType)
val unsafeArray = ctx.freshName("unsafeArray")
(if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
if (elementType == LongType) "Long" else "Int",
s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
if (elementType == LongType) "(long)" else "(int)",
s"""
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
|${ev.value} = $unsafeArray;
""".stripMargin)
case _ =>
val genericArrayData = classOf[GenericArrayData].getName
val et = ctx.addReferenceObj("elementType", elementType)
("", "Object",
s"get($i, $et)", s"update($pos, $value)", "Object", "",
s"${ev.value} = new $genericArrayData(new Object[$size]);")
}
} else {
("", "", "", "", "", "", "")
}
override def nullSafeEval(input1: Any, input2: Any): Any = {
val array1 = input1.asInstanceOf[ArrayData]
val array2 = input2.asInstanceOf[ArrayData]

exceptEquals(array1, array2)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Copy link
Contributor

@cloud-fan cloud-fan Jul 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think codegen can provide much value, given the complexity of this function. We do codegen mostly for supporting whole-stage-codegen. How about

val expr = ctx.addReference("arrayExceptExpr", this)
nullSafeCodeGen(ctx, ev, (array1, array2) => {
  s"$expr.nullSafeEval($array1, $array2)"
} 
"""

cc @rednaxelafx @rxin

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seem to make sense. Codegen can only increase the number of supporting types (i.e. byte/short).

val arrayData = classOf[ArrayData].getName
val expr = ctx.addReferenceObj("arrayExceptExpr", this)
nullSafeCodeGen(ctx, ev, (array1, array2) => {
if (openHashElementType != "") {
// Here, we ensure elementTypeSupportEquals is true
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
val hs = ctx.freshName("hs")
val arrayData = classOf[ArrayData].getName
val arrays = ctx.freshName("arrays")
val array = ctx.freshName("array")
val arrayDataIdx = ctx.freshName("arrayDataIdx")
s"""
|$openHashSet $hs = new $openHashSet$postFix($classTag);
|boolean $notFoundNullElement = true;
|int $size = 0;
|for (int $i = 0; $i < $array2.numElements(); $i++) {
| if ($array2.isNullAt($i)) {
| $notFoundNullElement = false;
| } else {
| $hs.add$postFix($array2.$getter);
| }
|}
|for (int $i = 0; $i < $array1.numElements(); $i++) {
| if ($array1.isNullAt($i)) {
| if ($notFoundNullElement) {
| $size++;
| $notFoundNullElement = false;
| }
| } else {
| $javaTypeName $value = $array1.$getter;
| if (!$hs.contains($castOp $value)) {
| $hs.add$postFix($value);
| $size++;
| }
| }
|}
|$arrayBuilder
|$hs = new $openHashSet$postFix($classTag);
|$notFoundNullElement = true;
|int $pos = 0;
|for (int $i = 0; $i < $array2.numElements(); $i++) {
| if ($array2.isNullAt($i)) {
| $notFoundNullElement = false;
| } else {
| $hs.add$postFix($array2.$getter);
| }
|}
|for (int $i = 0; $i < $array1.numElements(); $i++) {
| if ($array1.isNullAt($i)) {
| if ($notFoundNullElement) {
| ${ev.value}.setNullAt($pos++);
| $notFoundNullElement = false;
| }
| } else {
| $javaTypeName $value = $array1.$getter;
| if (!$hs.contains($castOp $value)) {
| $hs.add$postFix($value);
| ${ev.value}.$setter;
| $pos++;
| }
| }
|}
""".stripMargin
} else {
val arrayExcept = classOf[ArrayExcept].getName
val et = ctx.addReferenceObj("elementTypeIntersect", elementType)
val order = ctx.addReferenceObj("orderingIntersect", ordering)
val method = "exceptOrdering"
s"${ev.value} = $arrayExcept$$.MODULE$$.$method($array1, $array2, $et, $order);"
}
s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
})
}

override def prettyName: String = "array_except"
}

object ArrayExcept {
def exceptOrdering(
array1: ArrayData,
array2: ArrayData,
elementType: DataType,
ordering: Ordering[Any]): ArrayData = {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var scannedNullElements = false
var i = 0
while (i < array1.numElements()) {
var found = false
var elem1 = array1.get(i, elementType)
if (array1.isNullAt(i)) {
if (!scannedNullElements) {
var j = 0
while (!found && j < array2.numElements()) {
found = array2.isNullAt(j)
j += 1
}
// array2 is scanned only once for null element
scannedNullElements = true
} else {
found = true
}
} else {
var j = 0
while (!found && j < array2.numElements()) {
if (!array2.isNullAt(j)) {
val elem2 = array2.get(j, elementType)
found = ordering.equiv(elem1, elem2)
}
j += 1
}
if (!found) {
// check whether elem1 is already stored in arrayBuffer
var k = 0
while (!found && k < arrayBuffer.size) {
val va = arrayBuffer(k)
found = (va != null) && ordering.equiv(va, elem1)
k += 1
}
}
}
if (!found) {
arrayBuffer += elem1
}
i += 1
}
new GenericArrayData(arrayBuffer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1512,10 +1512,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true))
val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true))
val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false))
val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false))
val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false))
val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false))
val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false))

val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false))
val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false))
Expand Down Expand Up @@ -1544,8 +1540,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5))
checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5))
checkEvaluation(ArrayExcept(a06, a04), Seq.empty)
checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3))
checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3))

checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L))
checkEvaluation(ArrayExcept(a12, a11), Seq(1L))
Expand Down