Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add arrayOfUDT.
  • Loading branch information
viirya committed Nov 14, 2015
commit db644fb2be3a3fd32f7f62993575dc6d8cef594e
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression {

override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject :: Nil
override def children: Seq[Expression] = arguments.+:(targetObject)

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
Expand Down Expand Up @@ -343,33 +343,50 @@ case class MapObjects(
private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
private lazy val completeFunction = function(loopAttribute)

private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case IntegerType => (i: String) => s".getInt($i)"
case LongType => (i: String) => s".getLong($i)"
case FloatType => (i: String) => s".getFloat($i)"
case DoubleType => (i: String) => s".getDouble($i)"
case ByteType => (i: String) => s".getByte($i)"
case ShortType => (i: String) => s".getShort($i)"
case BooleanType => (i: String) => s".getBoolean($i)"
case StringType => (i: String) => s".getUTF8String($i)"
case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
case a: ArrayType => (i: String) => s".getArray($i)"
case _: MapType => (i: String) => s".getMap($i)"
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
}

private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
(".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
(".length", (i: String) => s"[$i]", false)
case ArrayType(s: StructType, _) =>
(".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
(".numElements()", itemAccessorMethod(s), false)
case ArrayType(a: ArrayType, _) =>
(".numElements()", (i: String) => s".getArray($i)", true)
(".numElements()", itemAccessorMethod(a), true)
case ArrayType(IntegerType, _) =>
(".numElements()", (i: String) => s".getInt($i)", true)
(".numElements()", itemAccessorMethod(IntegerType), true)
case ArrayType(LongType, _) =>
(".numElements()", (i: String) => s".getLong($i)", true)
(".numElements()", itemAccessorMethod(LongType), true)
case ArrayType(FloatType, _) =>
(".numElements()", (i: String) => s".getFloat($i)", true)
(".numElements()", itemAccessorMethod(FloatType), true)
case ArrayType(DoubleType, _) =>
(".numElements()", (i: String) => s".getDouble($i)", true)
(".numElements()", itemAccessorMethod(DoubleType), true)
case ArrayType(ByteType, _) =>
(".numElements()", (i: String) => s".getByte($i)", true)
(".numElements()", itemAccessorMethod(ByteType), true)
case ArrayType(ShortType, _) =>
(".numElements()", (i: String) => s".getShort($i)", true)
(".numElements()", itemAccessorMethod(ShortType), true)
case ArrayType(BooleanType, _) =>
(".numElements()", (i: String) => s".getBoolean($i)", true)
(".numElements()", itemAccessorMethod(BooleanType), true)
case ArrayType(StringType, _) =>
(".numElements()", (i: String) => s".getUTF8String($i)", false)
case ArrayType(_: MapType, _) =>
(".numElements()", (i: String) => s".getMap($i)", false)
(".numElements()", itemAccessorMethod(StringType), false)
case ArrayType(m: MapType, _) =>
(".numElements()", itemAccessorMethod(m), false)
case ArrayType(udt: UserDefinedType[_], _) =>
(".numElements()", itemAccessorMethod(udt.sqlType), false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we merge these branches together?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. updated.

}

override def nullable: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class RowEncoderSuite extends SparkFunSuite {
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)

encodeDecodeTest(
new StructType()
Expand Down Expand Up @@ -119,6 +120,18 @@ class RowEncoderSuite extends SparkFunSuite {
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))

test(s"encode/decode: arrayOfUDT") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why we put this test here instead of adding arrayOfUDT type in encodeDecodeTest like structOfUDT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. I moved it in #10538.

val schema = new StructType()
.add("arrayOfUDT", arrayOfUDT)

val encoder = RowEncoder(schema)

val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
}

test(s"encode/decode: Product") {
val schema = new StructType()
.add("structAsProduct",
Expand Down