
Commit 7515367

ueshin authored and rxin committed
[SPARK-1845] [SQL] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.

When I execute `orderBy` or `limit` on a `SchemaRDD` that includes an `ArrayType` or `MapType`, `SparkSqlSerializer` throws one of the following exceptions:

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.$colon$colon
```

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.Vector
```

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.HashMap$HashTrieMap
```

and so on. This happens because `SparkSqlSerializer` does not register serializers for the concrete collection classes. I believe it should use `AllScalaRegistrar`, which covers serializers for the concrete `Seq` and `Map` classes that back `ArrayType` and `MapType`.

Author: Takuya UESHIN <[email protected]>

Closes #790 from ueshin/issues/SPARK-1845 and squashes the following commits:

d1ed992 [Takuya UESHIN] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.

(cherry picked from commit db8cc6f)
Signed-off-by: Reynold Xin <[email protected]>
1 parent aa5f989 commit 7515367
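To make the failure mode concrete before reading the diff, here is a minimal reproduction sketch outside Spark (illustrative only, not part of the patch; it assumes just the `kryo` library on the classpath, and the object name is made up). A bare `Kryo` instance has no serializer registered for Scala's concrete collection classes, so deserializing even a plain `List` fails exactly as quoted above:

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Hypothetical repro: plain Kryo falls back to reflective instantiation
// for scala.collection.immutable.$colon$colon, which has no no-arg
// constructor, so the read side throws
// "KryoException: Class cannot be created (missing no-arg constructor)".
object KryoListRepro {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)

    val out = new Output(4096)
    kryo.writeClassAndObject(out, List(1, 2, 3)) // writing itself succeeds
    out.flush()

    val in = new Input(out.getBuffer, 0, out.position())
    kryo.readClassAndObject(in) // throws on $colon$colon
  }
}
```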

File tree: 4 files changed (+66, -26 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 2 additions & 26 deletions
```diff
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar
 
 import org.apache.spark.{SparkEnv, SparkConf}
 import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     val kryo = new Kryo()
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
-    kryo.register(classOf[Array[Any]])
-    // This is kinda hacky...
-    kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
-    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
       new HyperLogLogSerializer)
-    kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
     kryo.setClassLoader(Utils.getSparkClassLoader)
+    new AllScalaRegistrar().apply(kryo)
     kryo
   }
 }
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
     HyperLogLog.Builder.build(bytes)
   }
 }
-
-/**
- * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
- * them as `Array[(k,v)]`.
- */
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
-  def write(kryo: Kryo, output: Output, map: Map[_,_]) {
-    kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
-    kryo.readObject(input, classOf[Array[Any]])
-      .sliding(2,2)
-      .map { case Array(k,v) => (k,v) }
-      .toMap
-  }
-}
```
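For intuition on why the one-line replacement above suffices: `AllScalaRegistrar` (from Twitter's chill library, which Spark already depends on) registers serializers for the concrete Scala collection classes, covering both the failing `Seq`/`Map` cases from the bug report and everything the removed hand-rolled `MapSerializer` handled. A minimal sketch, assuming `kryo` and `chill` on the classpath (illustrative, not part of the patch):

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.AllScalaRegistrar

object AllScalaRegistrarSketch {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)
    new AllScalaRegistrar().apply(kryo) // the same call the patch adds

    val out = new Output(4096)
    kryo.writeClassAndObject(out, List(1, 2, 3))            // immutable.$colon$colon
    kryo.writeClassAndObject(out, Vector("a", "b"))         // immutable.Vector
    kryo.writeClassAndObject(out, Map(1 -> "a", 2 -> "b"))  // immutable.Map$Map2
    out.flush()

    val in = new Input(out.getBuffer, 0, out.position())
    println(kryo.readClassAndObject(in)) // List(1, 2, 3)
    println(kryo.readClassAndObject(in)) // Vector(a, b)
    println(kryo.readClassAndObject(in)) // Map(1 -> a, 2 -> b)
  }
}
```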

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 24 additions & 0 deletions
```diff
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
     checkAnswer(
       testData2.orderBy('a.desc, 'b.asc),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).asc),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).desc),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).asc),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).desc),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
   test("limit") {
     checkAnswer(
       testData.limit(10),
       testData.take(10).toSeq)
+
+    checkAnswer(
+      arrayData.limit(1),
+      arrayData.take(1).toSeq)
+
+    checkAnswer(
+      mapData.limit(1),
+      mapData.take(1).toSeq)
   }
 
   test("average") {
```

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 30 additions & 0 deletions
```diff
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+  }
+
+  test("limit") {
+    checkAnswer(
+      sql("SELECT * FROM testData LIMIT 10"),
+      testData.take(10).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData LIMIT 1"),
+      arrayData.collect().take(1).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData LIMIT 1"),
+      mapData.collect().take(1).toSeq)
   }
 
   test("average") {
```

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -74,6 +74,16 @@ object TestData {
       ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
   arrayData.registerAsTable("arrayData")
 
+  case class MapData(data: Map[Int, String])
+  val mapData =
+    TestSQLContext.sparkContext.parallelize(
+      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+      MapData(Map(1 -> "a4", 2 -> "b4")) ::
+      MapData(Map(1 -> "a5")) :: Nil)
+  mapData.registerAsTable("mapData")
+
   case class StringData(s: String)
   val repeatedData =
     TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
```
