@@ -143,8 +143,7 @@ object ScalaReflection extends ScalaReflection {
       walkedTypePath: Seq[String]): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
-    // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
-    // it's not trivial to support by-name resolution for StructType inside MapType.
+    case _: MapType => expr
     case _ => UpCast(expr, expected, walkedTypePath)
   }
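
Skipping the `UpCast` for `MapType` is what unlocks by-name resolution for structs nested in map values; the previously `ignore`d SPARK-19104 test is re-enabled below on the strength of it. A minimal sketch of the newly supported round-trip, assuming a live SparkSession named `spark` (`ClassData` mirrors the DatasetSuite fixture):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

case class ClassData(a: String, b: Int)

// The struct literal declares its columns as (b, a), the reverse of
// ClassData's field order, so positional binding would mismatch the types.
val df = spark.range(3).select(
  map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))

// With MapType no longer forced through UpCast, struct fields inside the
// map value are matched to ClassData's fields by name.
val ds = df.as[Map[Int, ClassData]]
ds.collect()
// Array(Map(1 -> ClassData("a", 0)), Map(1 -> ClassData("a", 1)), Map(1 -> ClassData("a", 2)))
```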

@@ -163,8 +162,8 @@ object ScalaReflection extends ScalaReflection {
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
-      walkedTypePath)
+    val input = upCastToExpectedType(
+      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
 
     val expr = deserializerFor(tpe, input, walkedTypePath)
     if (nullable) {
@@ -350,10 +349,10 @@ object ScalaReflection extends ScalaReflection {
       // TODO: add walked type path for map
       val TypeRef(_, _, Seq(keyType, valueType)) = t
 
-      CatalystToExternalMap(
+      UnresolvedCatalystToExternalMap(
+        path,
         p => deserializerFor(keyType, p, walkedTypePath),
         p => deserializerFor(valueType, p, walkedTypePath),
-        path,
         mirror.runtimeClass(t.typeSymbol.asClass)
       )

@@ -431,8 +430,8 @@ object ScalaReflection extends ScalaReflection {
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
 
     // The input object to `ExpressionEncoder` is located at the first column of a row.
-    val inputObject = BoundReference(0, dataTypeFor(tpe),
-      nullable = !tpe.typeSymbol.asClass.isPrimitive)
+    val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
+    val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)
 
     serializerFor(inputObject, tpe, walkedTypePath)
   }
@@ -2384,14 +2384,23 @@ class Analyzer(
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
           case ArrayType(et, cn) =>
-            val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
+            MapObjects(func, inputData, et, cn, cls) transformUp {
               case UnresolvedExtractValue(child, fieldName) if child.resolved =>
                 ExtractValue(child, fieldName, resolver)
             }
-            expr
           case other =>
             throw new AnalysisException("need an array field but got " + other.catalogString)
         }
+      case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
+        u.child.dataType match {
+          case _: MapType =>
+            CatalystToExternalMap(u) transformUp {
+              case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+                ExtractValue(child, fieldName, resolver)
+            }
+          case other =>
+            throw new AnalysisException("need a map field but got " + other.catalogString)
+        }
     }
     validateNestedTupleFields(result)
     result

Review thread on the new UnresolvedExtractValue rewrite:

Member: When u.child is resolved, is there still UnresolvedExtractValue?

Contributor Author: Yeah, I think so. The UnresolvedExtractValue might appear in CatalystToExternalMap.keyLambdaFunction and valueLambdaFunction.

Member: ResolveReferences might also process that, but it is also good to have them here.

Contributor Author: TBH I don't quite remember why I did this for MapObjects, so I just followed it here. Maybe we can remove it in a followup PR.
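
To ground the thread above: a hand-simplified sketch of the tree this new branch receives for a `Map[Int, ClassData]` column. The real trees built by `ScalaReflection.deserializerFor` carry additional null-handling and walked-type-path bookkeeping, so treat the shapes below as approximate:

```scala
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.{NewInstance, UnresolvedCatalystToExternalMap}
import org.apache.spark.sql.types._

case class ClassData(a: String, b: Int)

val valueStruct = StructType(Seq(
  StructField("a", StringType), StructField("b", IntegerType)))

val placeholder = UnresolvedCatalystToExternalMap(
  child = GetColumnByOrdinal(0, MapType(IntegerType, valueStruct)),
  keyFunction = key => key, // primitive key: nothing to extract
  // The value-side lambda drills into the struct with UnresolvedExtractValue;
  // these are exactly the nodes the transformUp above turns into ExtractValue
  // once u.child (and hence the struct's schema) is resolved.
  valueFunction = value => NewInstance(
    classOf[ClassData],
    Seq(
      UnresolvedExtractValue(value, Literal("a")),
      UnresolvedExtractValue(value, Literal("b"))),
    ObjectType(classOf[ClassData])),
  collClass = classOf[Map[_, _]])
```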
@@ -119,10 +119,9 @@ object ExpressionEncoder {
     }
 
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
-        .distinct
-      assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
-        s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
+      val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
+      assert(getColExprs.size == 1, "object deserializer should have only one " +
+        s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(GetColumnByOrdinal(0, schema), index)
       val newDeserializer = enc.objDeserializer.transformUp {

Contributor Author (on the getColExprs rename): Unrelated, but this fixes minor code style issues in #22749.
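
The assert guards the invariant the two lines directly above rely on: each child deserializer reads exactly one top-level column, namely `GetColumnByOrdinal(0, _)`, so it can be unambiguously re-pointed at its slot in the combined tuple struct. A small hand-rolled sketch of that rewrite, with stand-in values:

```scala
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField}
import org.apache.spark.sql.types._

// Stand-ins for this sketch; the real values come from the child encoders.
val tupleSchema = StructType(Seq(
  StructField("_1", StringType), StructField("_2", IntegerType)))
val childDeserializer: Expression = GetColumnByOrdinal(0, IntegerType)
val index = 1 // this child's position inside the tuple

// Re-point the child's single GetColumnByOrdinal at field `index` of the
// combined struct. Were there two distinct GetColumnByOrdinal nodes, there
// would be no single column to rebind -- hence the assert above.
val rebound = childDeserializer.transformUp {
  case GetColumnByOrdinal(0, _) =>
    GetStructField(GetColumnByOrdinal(0, tupleSchema), index)
}
// rebound: GetStructField(GetColumnByOrdinal(0, tupleSchema), 1)
```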
@@ -216,7 +215,6 @@ case class ExpressionEncoder[T](
     }
     nullSafeSerializer match {
       case If(_: IsNull, _, s: CreateNamedStruct) => s
-      case s: CreateNamedStruct => s
       case _ =>
         throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
     }

Member: Thanks!
@@ -30,14 +30,13 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 /**
@@ -963,25 +962,32 @@ case class MapObjects private(
   }
 }
 
+/**
+ * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]].
+ *
+ * @param child An expression that when evaluated returns a map object.
+ * @param keyFunction The function applied on the key collection elements.
+ * @param valueFunction The function applied on the value collection elements.
+ * @param collClass The type of the resulting collection.
+ */
+case class UnresolvedCatalystToExternalMap(
+    child: Expression,
+    @transient keyFunction: Expression => Expression,
+    @transient valueFunction: Expression => Expression,
+    collClass: Class[_]) extends UnaryExpression with Unevaluable {
+
+  override lazy val resolved = false
+
+  override def dataType: DataType = ObjectType(collClass)
+}
+
 object CatalystToExternalMap {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
-  /**
-   * Construct an instance of CatalystToExternalMap case class.
-   *
-   * @param keyFunction The function applied on the key collection elements.
-   * @param valueFunction The function applied on the value collection elements.
-   * @param inputData An expression that when evaluated returns a map object.
-   * @param collClass The type of the resulting collection.
-   */
-  def apply(
-      keyFunction: Expression => Expression,
-      valueFunction: Expression => Expression,
-      inputData: Expression,
-      collClass: Class[_]): CatalystToExternalMap = {
+  def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = {
     val id = curId.getAndIncrement()
     val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id"
-    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val mapType = u.child.dataType.asInstanceOf[MapType]
     val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
     val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id"
     val valueLoopIsNull = if (mapType.valueContainsNull) {
@@ -991,9 +997,9 @@
     }
     val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
     CatalystToExternalMap(
-      keyLoopValue, keyFunction(keyLoopVar),
-      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
-      inputData, collClass)
+      keyLoopValue, u.keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar),
+      u.child, u.collClass)
   }
 }
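
Putting the two pieces together: ScalaReflection builds the placeholder without looking at the child's data type, and the Analyzer branch added earlier converts it once the child resolves to a MapType. A condensed, hand-assembled sketch of that lifecycle (here the child is already typed, so the second step can run directly; real code passes key/value deserializers instead of identity functions):

```scala
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{CatalystToExternalMap, UnresolvedCatalystToExternalMap}
import org.apache.spark.sql.types._

val identityFn: Expression => Expression = e => e

// Phase 1 (ScalaReflection.deserializerFor): record *what* to build, without
// touching the child's data type, which may not be resolved yet.
val placeholder = UnresolvedCatalystToExternalMap(
  child = GetColumnByOrdinal(0, MapType(IntegerType, StringType)),
  keyFunction = identityFn,
  valueFunction = identityFn,
  collClass = classOf[Map[_, _]])

// Phase 2 (Analyzer): once the child is known to be a MapType, allocate the
// LambdaVariables against the now-known key/value types and build the real
// CatalystToExternalMap via the apply above.
val resolved: CatalystToExternalMap = CatalystToExternalMap(placeholder)
```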

@@ -1090,15 +1096,9 @@ case class CatalystToExternalMap private(
     val tupleLoopValue = ctx.freshName("tupleLoopValue")
     val builderValue = ctx.freshName("builderValue")
 
-    val getLength = s"${genInputData.value}.numElements()"
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
-    val getKeyArray =
-      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
     val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
-    val getValueArray =
-      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
     val getValueLoopVar = CodeGenerator.getValue(
       valueArray, inputDataType(mapType.valueType), loopIndex)

Contributor Author (@cloud-fan, Oct 24, 2018): These are unrelated, but are a followup of #16986 to address the remaining code style comments.

@@ -1147,10 +1147,10 @@
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
 
       if (!${genInputData.isNull}) {
-        int $dataLength = $getLength;
+        int $dataLength = ${genInputData.value}.numElements();
         $constructBuilder
-        $getKeyArray
-        $getValueArray
+        ArrayData $keyArray = ${genInputData.value}.keyArray();
+        ArrayData $valueArray = ${genInputData.value}.valueArray();
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
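
With the `$getKeyArray`/`$getValueArray` indirection gone, the template spells out the generated Java inline. A tiny standalone simulation of the interpolation (identifiers are hand-picked stand-ins for the `ctx.freshName` values); note that writing bare `ArrayData` instead of `${classOf[ArrayData].getName}` works only because, as far as I can tell, generated classes are compiled with default imports that include `org.apache.spark.sql.catalyst.util.ArrayData`:

```scala
// Simulates the string interpolation in the template above to show the
// Java lines the codegen now emits for the map-input prologue.
object GeneratedSnippetDemo extends App {
  val inputValue = "map_input"
  val dataLength = "map_dataLength"
  val keyArray = "map_keyArray"
  val valueArray = "map_valueArray"

  println(
    s"""int $dataLength = $inputValue.numElements();
       |ArrayData $keyArray = $inputValue.keyArray();
       |ArrayData $valueArray = $inputValue.valueArray();""".stripMargin)
}
```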
@@ -295,7 +295,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
   }
 
-  ignore("SPARK-19104: map and product combinations") {
+  test("SPARK-25817: map and product combinations") {
     // Case classes
     checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
     checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
@@ -164,6 +164,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq(ClassData("a", 2))))
   }
 
+  test("as map of case class - reorder fields by name") {
+    val df = spark.range(3).select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))
+    val ds = df.as[Map[Int, ClassData]]
+    assert(ds.collect() === Array(
+      Map(1 -> ClassData("a", 0)),
+      Map(1 -> ClassData("a", 1)),
+      Map(1 -> ClassData("a", 2))))
+  }
+
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(