Skip to content
Prev Previous commit
Next Next commit
create UnsafeArrayData from a primitive array in RowEncoder.serializeFor
  • Loading branch information
kiszk committed Nov 2, 2016
commit edfbce38c65dbb430f3e23aa0edbd0fd889958f0
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,33 @@ object RowEncoder {
"fromString",
inputObject :: Nil)

case t @ ArrayType(et, _) => et match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
// TODO: validate input type for primitive array.
NewInstance(
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
case _ => MapObjects(
element => serializerFor(ValidateExternalType(element, et), et),
inputObject,
ObjectType(classOf[Object]))
}
case t @ ArrayType(et, cn) =>
val cls = inputObject.dataType.asInstanceOf[ObjectType].cls
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we use the cls?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, removed

et match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is going on here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, my mistake

Copy link
Contributor

@cloud-fan cloud-fan Nov 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do the same thing here? i.e. special handling primitive array. I know we don't have the class information here, bu we can do it in the runtime:

object ArrayData {
  def toArrayData(input: Any): ArrayData = input match {
    case a: Array[Boolean] => UnsafeArrayData.fromPrimitive(a)
    ...
    case other => new GenericArrayData(other)
  }
}

Then we just use StaticInvoke here to call this method

Copy link
Member Author

@kiszk kiszk Nov 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. It worked well.

if !cn && (
cls.isAssignableFrom(classOf[Array[Boolean]]) ||
cls.isAssignableFrom(classOf[Array[Byte]]) ||
cls.isAssignableFrom(classOf[Array[Short]]) ||
cls.isAssignableFrom(classOf[Array[Int]]) ||
cls.isAssignableFrom(classOf[Array[Long]]) ||
cls.isAssignableFrom(classOf[Array[Float]]) ||
cls.isAssignableFrom(classOf[Array[Double]])) =>
StaticInvoke(
classOf[UnsafeArrayData],
ArrayType(et, false),
"fromPrimitiveArray",
inputObject :: Nil)
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
NewInstance(
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
case _ => MapObjects(
element => serializerFor(ValidateExternalType(element, et), et),
inputObject,
ObjectType(classOf[Object]))
}

case t @ MapType(kt, vt, valueNullable) =>
val keys =
Expand Down Expand Up @@ -193,6 +208,17 @@ object RowEncoder {
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
// In order to support both Array and Seq in external row, we make this as java.lang.Object.
case a @ ArrayType(et, cn) if !cn =>
et match {
case BooleanType => ObjectType(classOf[Array[Boolean]])
case ByteType => ObjectType(classOf[Array[Byte]])
case ShortType => ObjectType(classOf[Array[Short]])
case IntegerType => ObjectType(classOf[Array[Int]])
case LongType => ObjectType(classOf[Array[Long]])
case FloatType => ObjectType(classOf[Array[Float]])
case DoubleType => ObjectType(classOf[Array[Double]])
case _ => ObjectType(classOf[java.lang.Object])
}
case _: ArrayType => ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,12 @@ case class ValidateExternalType(child: Expression, expected: DataType)
case _: DecimalType =>
Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
case _: ArrayType =>
s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
case a @ ArrayType(et, cn) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

if (!cn && ctx.isPrimitiveType(et)) {
s"$obj instanceof ${ctx.javaType(et)}[]"
} else {
s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
}
case _ =>
s"$obj instanceof ${ctx.boxedType(dataType)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}

test("RowEncoder should support a primitive array") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: RowEncoder should support primitive arrays

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val schema = new StructType()
.add("booleanPrimitiveArray", ArrayType(BooleanType, false))
.add("bytePrimitiveArray", ArrayType(ByteType, false))
.add("shortPrimitiveArray", ArrayType(ShortType, false))
.add("intPrimitiveArray", ArrayType(IntegerType, false))
.add("longPrimitiveArray", ArrayType(LongType, false))
.add("floatPrimitiveArray", ArrayType(FloatType, false))
.add("doublePrimitiveArray", ArrayType(DoubleType, false))
val encoder = RowEncoder(schema).resolveAndBind()
val input = Seq(
Array(true, false),
Array(1.toByte, 64.toByte, Byte.MaxValue),
Array(1.toShort, 255.toShort, Short.MaxValue),
Array(1, 10000, Int.MaxValue),
Array(1.toLong, 1000000.toLong, Long.MaxValue),
Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
Array(11.1111, 123456.7890123, Double.MaxValue)
)
val row = encoder.toRow(Row.fromSeq(input))
val convertedBack = encoder.fromRow(row)
input.zipWithIndex.map { case (array, index) =>
assert(convertedBack.getSeq(index) === array)
}
}

test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType()
.add("array", ArrayType(IntegerType))
Expand Down