Skip to content
Prev Previous commit
Address comments.
  • Loading branch information
viirya committed Jul 12, 2016
commit 6065364da697cd29f9b31179063e6cf604aa25ef
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
MapObjects(deserializerFor(_), input, et, dataType),
MapObjects(deserializerFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,33 +360,7 @@ object MapObjects {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, None)
}

/**
* Construct an instance of MapObjects case class.
*
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
* @param inputDataType The explicitly given data type of inputData to override the
* data type inferred from inputData (i.e., inputData.dataType).
* When a Python UDT's sqlType is an array, the deserializer
* expression will apply MapObjects on it. However, the data type
* of inputData is then the Python UDT itself, which is not an array
* type that MapObjects expects. In this case, we need to explicitly
* use the Python UDT's sqlType as the data type.
*/
def apply(
function: Expression => Expression,
inputData: Expression,
elementType: DataType,
inputDataType: DataType): MapObjects = {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
Some(inputDataType))
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
}
}

Expand All @@ -407,16 +381,13 @@ object MapObjects {
* @param lambdaFunction A function that takes the `loopVar` as input and is used as the lambda
* function to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param inputDataType The optional dataType of inputData. If it is None, the default behavior is
* to use the resolved data type of the inputData.
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression,
inputDataType: Option[DataType]) extends Expression with NonSQLExpression {
inputData: Expression) extends Expression with NonSQLExpression {

override def nullable: Boolean = true

Expand Down Expand Up @@ -469,7 +440,14 @@ case class MapObjects private(
case _ => ""
}

val (getLength, getLoopVar) = inputDataType.getOrElse(inputData.dataType) match {
// Data of a PythonUserDefinedType is actually stored using the data type of its sqlType,
// so when we apply MapObjects to it, we have to use that sqlType instead.
val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about Scala UDT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be another UDT inside p.sqlType

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scala UDT is already covered by deserializerFor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to handle python udf differently from scala udf?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python udf has no userClass, so the regular handling of scala udf will fail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fixed by https://github.com/apache/spark/pull/13778/files#diff-47e9c0787b1c455e5bd4ad7b65df3436R209 . Can you double-check it? If we revert this change, will the test fail again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. let me check it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. The test will fail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to catch the python udt before passing it to MapObjects? I'm kind of worried about leaking python udt to a lot of places; we should handle it in just a few places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. let me try it.

case _ => inputData.dataType
}

val (getLength, getLoopVar) = inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
case ObjectType(cls) if cls.isArray =>
Expand All @@ -483,7 +461,7 @@ case class MapObjects private(
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}

val loopNullCheck = inputDataType.getOrElse(inputData.dataType) match {
val loopNullCheck = inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
Expand Down