Skip to content
Closed
Prev Previous commit
Next Next commit
specialize codegen for primitive types
  • Loading branch information
mgaido91 committed Apr 27, 2018
commit dc6cb60f5bee56473d65e50b500dea694c28d2b3
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ case class Slice(x: Expression, start: Expression, length: Expression)

override def children: Seq[Expression] = Seq(x, start, length)

lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
val startInt = startVal.asInstanceOf[Int]
val lengthInt = lengthVal.asInstanceOf[Int]
Expand All @@ -422,17 +424,12 @@ case class Slice(x: Expression, start: Expression, length: Expression)
if (startIndex < 0 || startIndex >= arr.numElements()) {
return new GenericArrayData(Array.empty[AnyRef])
}
val elementType = x.dataType.asInstanceOf[ArrayType].elementType
val data = arr.toSeq[AnyRef](elementType)
new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementType = x.dataType.asInstanceOf[ArrayType].elementType
nullSafeCodeGen(ctx, ev, (x, start, length) => {
val arrayClass = classOf[GenericArrayData].getName
val values = ctx.freshName("values")
val i = ctx.freshName("i")
val startIdx = ctx.freshName("startIdx")
val resLength = ctx.freshName("resLength")
val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
Expand All @@ -456,18 +453,60 @@ case class Slice(x: Expression, start: Expression, length: Expression)
|} else {
| $resLength = $length;
|}
|${genCodeForResult(ctx, ev, x, startIdx, resLength)}
""".stripMargin
})
}

def genCodeForResult(
ctx: CodegenContext,
ev: ExprCode,
inputArray: String,
startIdx: String,
resLength: String): String = {
val values = ctx.freshName("values")
val i = ctx.freshName("i")
val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayClass = classOf[GenericArrayData].getName
s"""
|Object[] $values;
|if ($startIdx < 0 || $startIdx >= $x.numElements()) {
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
| $values = new Object[0];
|} else {
| $values = new Object[$resLength];
| for (int $i = 0; $i < $resLength; $i ++) {
| $values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")};
| $values[$i] = $getValue;
| }
|}
|${ev.value} = new $arrayClass($values);
""".stripMargin
})
} else {
val sizeInBytes = ctx.freshName("sizeInBytes")
val bytesArray = ctx.freshName("bytesArray")
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
| $resLength = 0;
|}
|${CodeGenerator.JAVA_INT} $sizeInBytes =
| UnsafeArrayData.calculateHeaderPortionInBytes($resLength) +
| ${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
| ${elementType.defaultSize} * $resLength);
|byte[] $bytesArray = new byte[$sizeInBytes];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if sizeInBytes is larger than Integer.MAX_VALUE? For example, 0x7000_0000 long elements. In this case, GenericArrayData or long[] can hold these elements. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other places (eg Concat) in such a case we just throw a runtime exception. What about following the same pattern here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not even sure we have to add such a check actually, since here we can only reduce the size of an already existing array... Anyway probably it is ok to add an additional sanity check. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious about the following two cases.

  1. In UnsafeArray, long[] may be used. Its size is 0x8000_0000 * 4. On the other hand, the size is the allocated byte[] is up to 0x8000_0000.
  2. If GenericArray, which includes a lot of (e.g. 0x7F00_0000) Long or Double elements, is passed to this operation, the expected allocation size is more than 0x8000_0000.

While these cases reduce the size of an existing array, does the result array fit into byte[]? WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the same check which is performed in Concat and Flatten. If we want to support also larger arrays of primitives, we probably best have another PR which address the issue on all the functions affected (this one, Concat and Flatten), especially considering that the issue is much more likely to happen in the other two cases. Do you agree?

|UnsafeArrayData $values = new UnsafeArrayData();
|Platform.putLong($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $resLength);
|$values.pointTo($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $sizeInBytes);
|for (int $i = 0; $i < $resLength; $i ++) {
| if ($inputArray.isNullAt($i + $startIdx)) {
| $values.setNullAt($i);
| } else {
| $values.set$primitiveValueTypeName($i, $getValue);
| }
|}
|${ev.value} = $values;
""".stripMargin
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))

checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
Expand All @@ -120,6 +121,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
"Unexpected value for start")
checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
Expand All @@ -128,6 +130,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
}

test("Array Min") {
Expand Down