diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 40074b36f6a9..912744eab6a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -143,8 +143,7 @@ object ScalaReflection extends ScalaReflection {
       walkedTypePath: Seq[String]): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
-    // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
-    // it's not trivial to support by-name resolution for StructType inside MapType.
+    case _: MapType => expr
     case _ => UpCast(expr, expected, walkedTypePath)
   }

@@ -163,8 +162,8 @@ object ScalaReflection extends ScalaReflection {
     val Schema(dataType, nullable) = schemaFor(tpe)

     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
-      walkedTypePath)
+    val input = upCastToExpectedType(
+      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)

     val expr = deserializerFor(tpe, input, walkedTypePath)
     if (nullable) {
@@ -350,10 +349,10 @@ object ScalaReflection extends ScalaReflection {
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t

-        CatalystToExternalMap(
+        UnresolvedCatalystToExternalMap(
+          path,
           p => deserializerFor(keyType, p, walkedTypePath),
           p => deserializerFor(valueType, p, walkedTypePath),
-          path,
           mirror.runtimeClass(t.typeSymbol.asClass)
         )

@@ -431,8 +430,8 @@ object ScalaReflection extends ScalaReflection {
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil

     // The input object to `ExpressionEncoder` is located at first column of an row.
-    val inputObject = BoundReference(0, dataTypeFor(tpe),
-      nullable = !tpe.typeSymbol.asClass.isPrimitive)
+    val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
+    val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)

     serializerFor(inputObject, tpe, walkedTypePath)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 63a07e3898f2..c2d22c5e7ce6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2384,14 +2384,23 @@ class Analyzer(
      case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
        inputData.dataType match {
          case ArrayType(et, cn) =>
-            val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
+            MapObjects(func, inputData, et, cn, cls) transformUp {
              case UnresolvedExtractValue(child, fieldName) if child.resolved =>
                ExtractValue(child, fieldName, resolver)
            }
-            expr
          case other =>
            throw new AnalysisException("need an array field but got " + other.catalogString)
        }
+      case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
+        u.child.dataType match {
+          case _: MapType =>
+            CatalystToExternalMap(u) transformUp {
+              case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+                ExtractValue(child, fieldName, resolver)
+            }
+          case other =>
+            throw new AnalysisException("need a map field but got " + other.catalogString)
+        }
    }
    validateNestedTupleFields(result)
    result
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 29f6136a75ee..2c8e81ef17d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -119,10 +119,9 @@ object ExpressionEncoder {
     }

     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
-        .distinct
-      assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
-        s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
+      val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
+      assert(getColExprs.size == 1, "object deserializer should have only one " +
+        s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
       val input = GetStructField(GetColumnByOrdinal(0, schema), index)
       val newDeserializer = enc.objDeserializer.transformUp {
@@ -216,7 +215,6 @@ case class ExpressionEncoder[T](
     }
     nullSafeSerializer match {
       case If(_: IsNull, _, s: CreateNamedStruct) => s
-      case s: CreateNamedStruct => s
       case _ =>
         throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index b6f9b4734e94..4fd36a47cef5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -30,14 +30,13 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils

 /**
@@ -963,25 +962,32 @@ case class MapObjects private(
   }
 }

+/**
+ * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]].
+ *
+ * @param child An expression that when evaluated returns a map object.
+ * @param keyFunction The function applied on the key collection elements.
+ * @param valueFunction The function applied on the value collection elements.
+ * @param collClass The type of the resulting collection.
+ */
+case class UnresolvedCatalystToExternalMap(
+    child: Expression,
+    @transient keyFunction: Expression => Expression,
+    @transient valueFunction: Expression => Expression,
+    collClass: Class[_]) extends UnaryExpression with Unevaluable {
+
+  override lazy val resolved = false
+
+  override def dataType: DataType = ObjectType(collClass)
+}
+
 object CatalystToExternalMap {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()

-  /**
-   * Construct an instance of CatalystToExternalMap case class.
-   *
-   * @param keyFunction The function applied on the key collection elements.
-   * @param valueFunction The function applied on the value collection elements.
-   * @param inputData An expression that when evaluated returns a map object.
-   * @param collClass The type of the resulting collection.
-   */
-  def apply(
-      keyFunction: Expression => Expression,
-      valueFunction: Expression => Expression,
-      inputData: Expression,
-      collClass: Class[_]): CatalystToExternalMap = {
+  def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = {
     val id = curId.getAndIncrement()
     val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id"
-    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val mapType = u.child.dataType.asInstanceOf[MapType]
     val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
     val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id"
     val valueLoopIsNull = if (mapType.valueContainsNull) {
@@ -991,9 +997,9 @@ object CatalystToExternalMap {
     }
     val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
     CatalystToExternalMap(
-      keyLoopValue, keyFunction(keyLoopVar),
-      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
-      inputData, collClass)
+      keyLoopValue, u.keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar),
+      u.child, u.collClass)
   }
 }

@@ -1090,15 +1096,9 @@ case class CatalystToExternalMap private(
     val tupleLoopValue = ctx.freshName("tupleLoopValue")
     val builderValue = ctx.freshName("builderValue")

-    val getLength = s"${genInputData.value}.numElements()"
-
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
-    val getKeyArray =
-      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
     val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
-    val getValueArray =
-      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
     val getValueLoopVar = CodeGenerator.getValue(
       valueArray, inputDataType(mapType.valueType), loopIndex)

@@ -1147,10 +1147,10 @@ case class CatalystToExternalMap private(
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};

       if (!${genInputData.isNull}) {
-        int $dataLength = $getLength;
+        int $dataLength = ${genInputData.value}.numElements();
         $constructBuilder
-        $getKeyArray
-        $getValueArray
+        ArrayData $keyArray = ${genInputData.value}.keyArray();
+        ArrayData $valueArray = ${genInputData.value}.valueArray();

         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index edcdd77908d3..96a6792f52f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -295,7 +295,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
   }

-  ignore("SPARK-19104: map and product combinations") {
+  test("SPARK-25817: map and product combinations") {
     // Case classes
     checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
     checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 27b3b3d78d2b..82d3b22a4867 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -164,6 +164,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq(ClassData("a", 2))))
   }

+  test("as map of case class - reorder fields by name") {
+    val df = spark.range(3).select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))
+    val ds = df.as[Map[Int, ClassData]]
+    assert(ds.collect() === Array(
+      Map(1 -> ClassData("a", 0)),
+      Map(1 -> ClassData("a", 1)),
+      Map(1 -> ClassData("a", 2))))
+  }
+
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(
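
For context, here is a minimal, self-contained sketch (not part of the patch) of what the change enables: the deserializer now wraps nested map values in UnresolvedCatalystToExternalMap, so the analyzer can resolve struct fields inside map values by name, and a map<int, struct> column whose struct fields are out of order can still be read as a Dataset of Map[Int, <case class>]. The SparkSession setup and object name below are illustrative assumptions, not code from this diff.

// Illustrative usage sketch; assumes a Spark build that contains this patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

case class ClassData(a: String, b: Int)

object MapOfCaseClassExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-25817 example")
      .getOrCreate()
    import spark.implicits._

    // The struct lists `b` before `a`; the encoder matches struct fields to
    // ClassData's constructor parameters by name rather than by position.
    val df = spark.range(3)
      .select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))

    val ds = df.as[Map[Int, ClassData]]
    // Prints Map(1 -> ClassData(a,0)), Map(1 -> ClassData(a,1)), Map(1 -> ClassData(a,2))
    ds.collect().foreach(println)

    spark.stop()
  }
}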