@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar

import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
val kryo = new Kryo()
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
-kryo.register(classOf[Array[Any]])
-// This is kinda hacky...
-kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
-kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
-kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
-kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
-kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
-kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
-kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
+new AllScalaRegistrar().apply(kryo)
kryo
}
}
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
HyperLogLog.Builder.build(bytes)
}
}
-
-/**
-* Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
-* them as `Array[(k,v)]`.
-*/
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
-def write(kryo: Kryo, output: Output, map: Map[_,_]) {
-kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
-}
-
-def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
-kryo.readObject(input, classOf[Array[Any]])
-.sliding(2,2)
-.map { case Array(k,v) => (k,v) }
-.toMap
-}
-}
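
The hand-written MapSerializer above becomes unnecessary because chill's AllScalaRegistrar, applied above with `new AllScalaRegistrar().apply(kryo)`, registers Kryo serializers for the standard Scala collections, including immutable Maps, which Kryo cannot otherwise instantiate because they lack a no-arg constructor. A minimal sketch of the resulting round trip outside Spark, mirroring the serializer settings in this patch; the object name and buffer size are illustrative, not part of the change:

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.AllScalaRegistrar

object MapRoundTrip {
  def main(args: Array[String]): Unit = {
    // Mirror the relevant SparkSqlSerializer settings: registration is optional,
    // reference tracking is off, and chill registers the Scala collection classes.
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)
    kryo.setReferences(false)
    new AllScalaRegistrar().apply(kryo)

    val original = Map(1 -> "a1", 2 -> "b1", 3 -> "c1")

    // Write the map together with its class so it can be read back generically.
    val output = new Output(4096)
    kryo.writeClassAndObject(output, original)
    output.close()

    // Read it back; no hand-written Map serializer is involved.
    val restored = kryo.readClassAndObject(new Input(output.toBytes))
    assert(restored == original, s"expected $original, got $restored")
  }
}
```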
sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala (24 additions, 0 deletions)
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+checkAnswer(
+arrayData.orderBy(GetItem('data, 0).asc),
+arrayData.collect().sortBy(_.data(0)).toSeq)
+
+checkAnswer(
+arrayData.orderBy(GetItem('data, 0).desc),
+arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+checkAnswer(
+mapData.orderBy(GetItem('data, 1).asc),
+mapData.collect().sortBy(_.data(1)).toSeq)
+
+checkAnswer(
+mapData.orderBy(GetItem('data, 1).desc),
+mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}

test("limit") {
checkAnswer(
testData.limit(10),
testData.take(10).toSeq)
+
+checkAnswer(
+arrayData.limit(1),
+arrayData.take(1).toSeq)
+
+checkAnswer(
+mapData.limit(1),
+mapData.take(1).toSeq)
}

test("average") {
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (30 additions, 0 deletions)
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+checkAnswer(
+sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+arrayData.collect().sortBy(_.data(0)).toSeq)
+
+checkAnswer(
+sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+checkAnswer(
+sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+mapData.collect().sortBy(_.data(1)).toSeq)
+
+checkAnswer(
+sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}
+
+test("limit") {
+checkAnswer(
+sql("SELECT * FROM testData LIMIT 10"),
+testData.take(10).toSeq)
+
+checkAnswer(
+sql("SELECT * FROM arrayData LIMIT 1"),
+arrayData.collect().take(1).toSeq)
+
+checkAnswer(
+sql("SELECT * FROM mapData LIMIT 1"),
+mapData.collect().take(1).toSeq)
+}

test("average") {
sql/core/src/test/scala/org/apache/spark/sql/TestData.scala (10 additions, 0 deletions)
@@ -74,6 +74,16 @@ object TestData {
ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
arrayData.registerAsTable("arrayData")

+case class MapData(data: Map[Int, String])
+val mapData =
+TestSQLContext.sparkContext.parallelize(
+MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+MapData(Map(1 -> "a4", 2 -> "b4")) ::
+MapData(Map(1 -> "a5")) :: Nil)
+mapData.registerAsTable("mapData")
+
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))