python/pyspark/sql/tests.py (35 additions, 0 deletions)

@@ -575,6 +575,41 @@ def check_datatype(datatype):
        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))

    def test_simple_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
[Review comment]
This should be in its own test method - it's no longer merely a test_udt but rather a test_simple_udt_in_df.

        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.show()
[Review comment]
DataFrame.show() gives unnecessary stringification, so this test ends up testing unnecessary stuff (in fact, it would fail if the UDT didn't have __str__). I would use collect() to force materialization instead.

[@viirya (Member, Author), Jun 22, 2016]
This test only fails when using show(), as I mentioned on the JIRA ticket SPARK-16062.


    def test_nested_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
[Review comment]
Same here - this needs its own named unit test (so that it's easier to identify the problem if a test fails). Unit tests should test only one thing; the thing tested here is test_nested_udt_in_df. (Perhaps it is also worthwhile to check that Map works?)

        df = self.spark.createDataFrame(
            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
            schema=schema)
        df.collect()

        schema = StructType().add("key", LongType()).add("val",
                                                          MapType(LongType(), PythonOnlyUDT()))
        df = self.spark.createDataFrame(
            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
            schema=schema)
        df.collect()

    def test_complex_nested_udt_in_df(self):
        from pyspark.sql.functions import udf

        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

        gd = df.groupby("key").agg({"val": "collect_list"})
        gd.collect()
        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        gd.select(udf(*gd)).collect()

    def test_udt_with_none(self):
        df = self.spark.range(0, 10, 1, 1)

RowEncoder.scala

@@ -206,6 +206,7 @@ object RowEncoder {
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
+    case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }
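A brief, hedged illustration of what the new PythonUserDefinedType case buys: a python-only UDT has no JVM userClass to wrap in ObjectType, but its sqlType is an ordinary Catalyst type, so delegating to it yields a usable external representation. The array-of-doubles shape below is an assumption mirroring PythonOnlyUDT from the PySpark tests above, not part of this patch:

import org.apache.spark.sql.types._

// Assumed shape of a python-only UDT's underlying SQL type, mirroring
// PythonOnlyUDT in the tests above: an array of doubles.
val sqlTypeOfPyUdt: DataType = ArrayType(DoubleType, containsNull = false)

// With the new case, externalDataTypeFor recurses into p.sqlType, so this
// UDT's external type becomes the ArrayType one, ObjectType(classOf[Seq[_]]),
// instead of an ObjectType for a Python class the JVM cannot instantiate.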

@@ -220,9 +221,15 @@ object RowEncoder {

     CreateExternalRow(fields, schema)
   }

-  private def deserializerFor(input: Expression): Expression = input.dataType match {
+  private def deserializerFor(input: Expression): Expression = {
+    deserializerFor(input, input.dataType)
+  }

[Contributor comment]
Seems that this method is never used?

[@viirya (Member, Author)]
Uh? It is the original deserializerFor method, and it is used below and above.

[@liancheng (Contributor), Jun 20, 2016]
Oh, sorry... I was confused by the split diff view.

+  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
     case dt if ScalaReflection.isNativeType(dt) => input

+    case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
+
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
       val udtClass: Class[_] = if (annotation != null) {
@@ -262,7 +269,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(deserializerFor(_), input, et),
+          MapObjects(deserializerFor(_), input, et, dataType),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
objects.scala

@@ -346,14 +346,47 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
 object MapObjects {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()

+  /**
+   * Construct an instance of MapObjects case class.
+   *
+   * @param function The function applied on the collection elements.
+   * @param inputData An expression that when evaluated returns a collection object.
+   * @param elementType The data type of elements in the collection.
+   */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
       elementType: DataType): MapObjects = {
     val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
     val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, None)

[@viirya (Member, Author)]
It is possible that inputData is unresolved at this point, so we can't just pass in the data type of inputData. That is why I still make inputDataType an Option[DataType] below.

   }
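A minimal sketch of the constraint @viirya describes above, using Catalyst's UnresolvedAttribute (which reports resolved = false and throws from dataType until the analyzer has run): the one-argument apply cannot eagerly read inputData.dataType, hence the None.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

// An expression the analyzer has not resolved yet.
val unresolved = UnresolvedAttribute("someColumn")
assert(!unresolved.resolved)

// unresolved.dataType would throw an UnresolvedException here, which is why
// MapObjects stores inputDataType as Option[DataType] and only falls back to
// inputData.dataType lazily, once the expression tree has been resolved.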

+  /**
+   * Construct an instance of MapObjects case class.
+   *
+   * @param function The function applied on the collection elements.
+   * @param inputData An expression that when evaluated returns a collection object.
+   * @param elementType The data type of elements in the collection.
+   * @param inputDataType The explicitly given data type of inputData, overriding the
+   *                      data type inferred from inputData (i.e., inputData.dataType).

[Review comment]
Would you mind adding to the documentation why this would ever be necessary (i.e., in the array-of-Python-UDT case, what's wrong with inputData.dataType)?

[@viirya (Member, Author)]
OK. Added.

+   *                      When a Python UDT's sqlType is an array, the deserializer
+   *                      expression will apply MapObjects to it. However, the data type
+   *                      of inputData is then the Python UDT itself, which is not the
+   *                      array type that MapObjects expects. In this case, we need to
+   *                      explicitly use the Python UDT's sqlType as the data type.

[Contributor comment]
As we have to mention Python UDTs in MapObjects anyway, I think it makes more sense to add the Python UDT handling in MapObjects directly.

[@viirya (Member, Author)]
Do you mean the earlier commit? I remember that is the first approach I took.

[@viirya (Member, Author)]
But I think it exposes Python UDTs to MapObjects?

[Contributor comment]
But now we expose them too. Readers have to know about Python UDTs to understand this code.

[@viirya (Member, Author)]
Hmm, OK. Let me update it.

[Contributor comment]
If we can hide Python UDTs from MapObjects entirely, it is worth doing. But it looks like we can't, and then I think it makes more sense to expose Python UDTs more explicitly.

[@viirya (Member, Author)]
Makes sense. I will update it later.

+   */
+  def apply(
+      function: Expression => Expression,
+      inputData: Expression,
+      elementType: DataType,
+      inputDataType: DataType): MapObjects = {
+    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
+    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+    val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+      Some(inputDataType))
+  }
 }
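To make the docstring discussion above concrete, here is a toy sketch (simplified stand-ins, not Spark's actual classes) of why matching on the input's declared type finds no usable branch until the underlying sqlType is passed explicitly:

// Simplified stand-ins for the relevant Catalyst types.
sealed trait ToyType
case class ToyArray(elem: ToyType) extends ToyType
case class ToyPythonUDT(sqlType: ToyType) extends ToyType
case object ToyDouble extends ToyType

// Mirrors the getLength/getLoopVar match in MapObjects.doGenCode: only
// array-like inputs are handled, so when a column's declared type is a
// Python UDT, the caller must supply udt.sqlType as inputDataType for the
// array branch to fire.
def lengthAccessor(inputType: ToyType): String = inputType match {
  case ToyArray(_) => "input.numElements()"
  case other => sys.error(s"MapObjects cannot iterate over $other")
}

// lengthAccessor(ToyPythonUDT(ToyArray(ToyDouble)))  // fails: the UDT hides the array
lengthAccessor(ToyArray(ToyDouble))                   // works: pass the sqlType instead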

@@ -374,13 +407,16 @@ object MapObjects {
  * @param lambdaFunction A function that takes the `loopVar` as input, and is used as the lambda
  *                       function to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
+ * @param inputDataType The optional data type of inputData. If it is None, the default behavior
+ *                      is to use the resolved data type of inputData.
  */
 case class MapObjects private(
     loopValue: String,
     loopIsNull: String,
     loopVarDataType: DataType,
     lambdaFunction: Expression,
-    inputData: Expression) extends Expression with NonSQLExpression {
+    inputData: Expression,
+    inputDataType: Option[DataType]) extends Expression with NonSQLExpression {

   override def nullable: Boolean = true

@@ -433,8 +469,7 @@ case class MapObjects private(
       case _ => ""
     }

-    val (getLength, getLoopVar) = inputData.dataType match {
+    val (getLength, getLoopVar) = inputDataType.getOrElse(inputData.dataType) match {
       case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
         s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
       case ObjectType(cls) if cls.isArray =>
@@ -448,7 +483,7 @@ case class MapObjects private(
           s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
       }

-    val loopNullCheck = inputData.dataType match {
+    val loopNullCheck = inputDataType.getOrElse(inputData.dataType) match {
       case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>