Skip to content

Commit e580bb0

Browse files
arayliancheng
authored andcommitted
[SPARK-18717][SQL] Make code generation for Scala Map work with immutable.Map also
## What changes were proposed in this pull request? Fixes compile errors in generated code when user has case class with a `scala.collections.immutable.Map` instead of a `scala.collections.Map`. Since ArrayBasedMapData.toScalaMap returns the immutable version we can make it work with both. ## How was this patch tested? Additional unit tests. Author: Andrew Ray <ray.andrew@gmail.com> Closes #16161 from aray/fix-map-codegen. (cherry picked from commit 46d30ac) Signed-off-by: Cheng Lian <lian@databricks.com>
1 parent 7b5ea00 commit e580bb0

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ object ScalaReflection extends ScalaReflection {
342342

343343
StaticInvoke(
344344
ArrayBasedMapData.getClass,
345-
ObjectType(classOf[Map[_, _]]),
345+
ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
346346
"toScalaMap",
347347
keyData :: valueData :: Nil)
348348

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
10631063
// sizeInBytes is 2404280404, before the fix, it overflows to a negative number
10641064
assert(sizeInBytes > 0)
10651065
}
1066+
1067+
test("SPARK-18717: code generation works for both scala.collection.Map" +
1068+
" and scala.collection.imutable.Map") {
1069+
val ds = Seq(WithImmutableMap("hi", Map(42L -> "foo"))).toDS
1070+
checkDataset(ds.map(t => t), WithImmutableMap("hi", Map(42L -> "foo")))
1071+
1072+
val ds2 = Seq(WithMap("hi", Map(42L -> "foo"))).toDS
1073+
checkDataset(ds2.map(t => t), WithMap("hi", Map(42L -> "foo")))
1074+
}
10661075
}
10671076

1077+
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
1078+
case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
1079+
10681080
case class Generic[T](id: T, value: Double)
10691081

10701082
case class OtherTuple(_1: String, _2: Int)

0 commit comments

Comments
 (0)