Skip to content

Commit 1887fa2

Browse files
viiryadavies
authored andcommitted
[SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder
JIRA: https://issues.apache.org/jira/browse/SPARK-11743 RowEncoder doesn't support UserDefinedType now. We should add the support for it. Author: Liang-Chi Hsieh <[email protected]> Closes #9712 from viirya/rowencoder-udt. (cherry picked from commit b0c3fd3) Signed-off-by: Davies Liu <[email protected]>
1 parent 949c9b7 commit 1887fa2

File tree

4 files changed

+139
-29
lines changed

4 files changed

+139
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ trait Row extends Serializable {
152152
* BinaryType -> byte array
153153
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
154154
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
155-
* StructType -> org.apache.spark.sql.Row
155+
* StructType -> org.apache.spark.sql.Row (or Product)
156156
* }}}
157157
*/
158158
def apply(i: Int): Any = get(i)
@@ -177,7 +177,7 @@ trait Row extends Serializable {
177177
* BinaryType -> byte array
178178
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
179179
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
180-
* StructType -> org.apache.spark.sql.Row
180+
* StructType -> org.apache.spark.sql.Row (or Product)
181181
* }}}
182182
*/
183183
def get(i: Int): Any
@@ -306,7 +306,15 @@ trait Row extends Serializable {
306306
*
307307
* @throws ClassCastException when data type does not match.
308308
*/
309-
def getStruct(i: Int): Row = getAs[Row](i)
309+
def getStruct(i: Int): Row = {
310+
// Product and Row both are recoginized as StructType in a Row
311+
val t = get(i)
312+
if (t.isInstanceOf[Product]) {
313+
Row.fromTuple(t.asInstanceOf[Product])
314+
} else {
315+
t.asInstanceOf[Row]
316+
}
317+
}
310318

311319
/**
312320
* Returns the value at position i.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ object RowEncoder {
5050
case BooleanType | ByteType | ShortType | IntegerType | LongType |
5151
FloatType | DoubleType | BinaryType => inputObject
5252

53+
case udt: UserDefinedType[_] =>
54+
val obj = NewInstance(
55+
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
56+
Nil,
57+
false,
58+
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
59+
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
60+
5361
case TimestampType =>
5462
StaticInvoke(
5563
DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {
109117

110118
case StructType(fields) =>
111119
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
120+
val method = if (f.dataType.isInstanceOf[StructType]) {
121+
"getStruct"
122+
} else {
123+
"get"
124+
}
112125
If(
113126
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
114127
Literal.create(null, f.dataType),
115128
extractorsFor(
116-
Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
129+
Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
117130
f.dataType))
118131
}
119132
CreateStruct(convertedFields)
@@ -137,6 +150,7 @@ object RowEncoder {
137150
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
138151
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
139152
case _: StructType => ObjectType(classOf[Row])
153+
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
140154
}
141155

142156
private def constructorFor(schema: StructType): Expression = {
@@ -155,6 +169,14 @@ object RowEncoder {
155169
case BooleanType | ByteType | ShortType | IntegerType | LongType |
156170
FloatType | DoubleType | BinaryType => input
157171

172+
case udt: UserDefinedType[_] =>
173+
val obj = NewInstance(
174+
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
175+
Nil,
176+
false,
177+
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
178+
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
179+
158180
case TimestampType =>
159181
StaticInvoke(
160182
DateTimeUtils,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ case class Invoke(
113113
arguments: Seq[Expression] = Nil) extends Expression {
114114

115115
override def nullable: Boolean = true
116-
override def children: Seq[Expression] = targetObject :: Nil
116+
override def children: Seq[Expression] = arguments.+:(targetObject)
117117

118118
override def eval(input: InternalRow): Any =
119119
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -343,33 +343,35 @@ case class MapObjects(
343343
private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
344344
private lazy val completeFunction = function(loopAttribute)
345345

346+
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
347+
case IntegerType => (i: String) => s".getInt($i)"
348+
case LongType => (i: String) => s".getLong($i)"
349+
case FloatType => (i: String) => s".getFloat($i)"
350+
case DoubleType => (i: String) => s".getDouble($i)"
351+
case ByteType => (i: String) => s".getByte($i)"
352+
case ShortType => (i: String) => s".getShort($i)"
353+
case BooleanType => (i: String) => s".getBoolean($i)"
354+
case StringType => (i: String) => s".getUTF8String($i)"
355+
case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
356+
case a: ArrayType => (i: String) => s".getArray($i)"
357+
case _: MapType => (i: String) => s".getMap($i)"
358+
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
359+
}
360+
346361
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
347362
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
348363
(".size()", (i: String) => s".apply($i)", false)
349364
case ObjectType(cls) if cls.isArray =>
350365
(".length", (i: String) => s"[$i]", false)
351-
case ArrayType(s: StructType, _) =>
352-
(".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
353-
case ArrayType(a: ArrayType, _) =>
354-
(".numElements()", (i: String) => s".getArray($i)", true)
355-
case ArrayType(IntegerType, _) =>
356-
(".numElements()", (i: String) => s".getInt($i)", true)
357-
case ArrayType(LongType, _) =>
358-
(".numElements()", (i: String) => s".getLong($i)", true)
359-
case ArrayType(FloatType, _) =>
360-
(".numElements()", (i: String) => s".getFloat($i)", true)
361-
case ArrayType(DoubleType, _) =>
362-
(".numElements()", (i: String) => s".getDouble($i)", true)
363-
case ArrayType(ByteType, _) =>
364-
(".numElements()", (i: String) => s".getByte($i)", true)
365-
case ArrayType(ShortType, _) =>
366-
(".numElements()", (i: String) => s".getShort($i)", true)
367-
case ArrayType(BooleanType, _) =>
368-
(".numElements()", (i: String) => s".getBoolean($i)", true)
369-
case ArrayType(StringType, _) =>
370-
(".numElements()", (i: String) => s".getUTF8String($i)", false)
371-
case ArrayType(_: MapType, _) =>
372-
(".numElements()", (i: String) => s".getMap($i)", false)
366+
case ArrayType(t, _) =>
367+
val (sqlType, primitiveElement) = t match {
368+
case m: MapType => (m, false)
369+
case s: StructType => (s, false)
370+
case s: StringType => (s, false)
371+
case udt: UserDefinedType[_] => (udt.sqlType, false)
372+
case o => (o, true)
373+
}
374+
(".numElements()", itemAccessorMethod(sqlType), primitiveElement)
373375
}
374376

375377
override def nullable: Boolean = true

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.{RandomDataGenerator, Row}
22+
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.unsafe.types.UTF8String
2425

26+
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
27+
class ExamplePoint(val x: Double, val y: Double) extends Serializable {
28+
override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
29+
override def equals(that: Any): Boolean = {
30+
if (that.isInstanceOf[ExamplePoint]) {
31+
val e = that.asInstanceOf[ExamplePoint]
32+
(this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
33+
(this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
34+
} else {
35+
false
36+
}
37+
}
38+
}
39+
40+
/**
41+
* User-defined type for [[ExamplePoint]].
42+
*/
43+
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
44+
45+
override def sqlType: DataType = ArrayType(DoubleType, false)
46+
47+
override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
48+
49+
override def serialize(obj: Any): GenericArrayData = {
50+
obj match {
51+
case p: ExamplePoint =>
52+
val output = new Array[Any](2)
53+
output(0) = p.x
54+
output(1) = p.y
55+
new GenericArrayData(output)
56+
}
57+
}
58+
59+
override def deserialize(datum: Any): ExamplePoint = {
60+
datum match {
61+
case values: ArrayData =>
62+
new ExamplePoint(values.getDouble(0), values.getDouble(1))
63+
}
64+
}
65+
66+
override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
67+
68+
private[spark] override def asNullable: ExamplePointUDT = this
69+
}
70+
2571
class RowEncoderSuite extends SparkFunSuite {
2672

2773
private val structOfString = new StructType().add("str", StringType)
74+
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
2875
private val arrayOfString = ArrayType(StringType)
2976
private val mapOfString = MapType(StringType, StringType)
77+
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
3078

3179
encodeDecodeTest(
3280
new StructType()
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {
4189
.add("string", StringType)
4290
.add("binary", BinaryType)
4391
.add("date", DateType)
44-
.add("timestamp", TimestampType))
92+
.add("timestamp", TimestampType)
93+
.add("udt", new ExamplePointUDT, false))
4594

4695
encodeDecodeTest(
4796
new StructType()
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {
68117
.add("structOfArray", new StructType().add("array", arrayOfString))
69118
.add("structOfMap", new StructType().add("map", mapOfString))
70119
.add("structOfArrayAndMap",
71-
new StructType().add("array", arrayOfString).add("map", mapOfString)))
120+
new StructType().add("array", arrayOfString).add("map", mapOfString))
121+
.add("structOfUDT", structOfUDT))
122+
123+
test(s"encode/decode: arrayOfUDT") {
124+
val schema = new StructType()
125+
.add("arrayOfUDT", arrayOfUDT)
126+
127+
val encoder = RowEncoder(schema)
128+
129+
val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
130+
val row = encoder.toRow(input)
131+
val convertedBack = encoder.fromRow(row)
132+
assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
133+
}
134+
135+
test(s"encode/decode: Product") {
136+
val schema = new StructType()
137+
.add("structAsProduct",
138+
new StructType()
139+
.add("int", IntegerType)
140+
.add("string", StringType)
141+
.add("double", DoubleType))
142+
143+
val encoder = RowEncoder(schema)
144+
145+
val input: Row = Row((100, "test", 0.123))
146+
val row = encoder.toRow(input)
147+
val convertedBack = encoder.fromRow(row)
148+
assert(input.getStruct(0) == convertedBack.getStruct(0))
149+
}
72150

73151
private def encodeDecodeTest(schema: StructType): Unit = {
74152
test(s"encode/decode: ${schema.simpleString}") {

0 commit comments

Comments
 (0)