35 changes: 35 additions & 0 deletions python/pyspark/sql/tests.py
@@ -558,6 +558,41 @@ def check_datatype(datatype):
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))

def test_simple_udt_in_df(self):
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
This should be in its own test method - it's no longer merely a test_udt but rather a test_simple_udt_in_df.

df = self.spark.createDataFrame(
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
schema=schema)
df.show()
DataFrame.show() adds unnecessary stringification, so this test ends up exercising more than it needs to (in fact it would fail if the UDT didn't implement __str__). I would use collect() to force materialization instead.

@viirya viirya (Member Author), Jun 22, 2016:
This test only fails when using show(), as I mentioned on JIRA SPARK-16062.
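For context, a minimal illustration of the two calls being discussed. This is only a sketch: it assumes a running SparkSession bound to a variable named spark, plus the PythonOnlyPoint/PythonOnlyUDT helpers and the type imports already present in this test file.

schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = spark.createDataFrame([(1, PythonOnlyPoint(1.0, 1.0))], schema=schema)

rows = df.collect()   # forces materialization; each row's val is a PythonOnlyPoint object
df.show()             # additionally renders every value as text; per the discussion above, this is the call that fails for SPARK-16062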


def test_nested_udt_in_df(self):
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
Same here, this needs its own named unit test (so that it's easier to identify the problem if the test fails). Unit tests should test only one thing, and the thing tested here is test_nested_udt_in_df. (Perhaps it's also worthwhile to check that Map works?)

df = self.spark.createDataFrame(
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
schema=schema)
df.collect()

schema = StructType().add("key", LongType()).add("val",
MapType(LongType(), PythonOnlyUDT()))
df = self.spark.createDataFrame(
[(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
schema=schema)
df.collect()

def test_complex_nested_udt_in_df(self):
from pyspark.sql.functions import udf

schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = self.spark.createDataFrame(
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
schema=schema)
df.collect()

gd = df.groupby("key").agg({"val": "collect_list"})
gd.collect()
udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
gd.select(udf(*gd)).collect()

Missing a unit test for "counterexample 2": a nested complex UDT struct with mixed types, i.e. ArrayType(StructType(simple SQL type, UDT)).
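For reference, a rough sketch of what such a test could look like. The method name, field names, and data below are illustrative and not part of this patch; the sketch assumes it sits in the same test class and reuses the PythonOnlyPoint/PythonOnlyUDT helpers defined earlier in this file.

def test_nested_udt_in_struct_in_df(self):
    # "counterexample 2": a struct mixing a plain SQL type with a UDT,
    # nested inside an array column.
    inner = StructType().add("id", LongType()).add("point", PythonOnlyUDT())
    schema = StructType().add("key", LongType()).add("val", ArrayType(inner))
    df = self.spark.createDataFrame(
        [(i % 3, [(i, PythonOnlyPoint(float(i), float(i)))]) for i in range(10)],
        schema=schema)
    df.collect()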

def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
@@ -206,6 +206,7 @@ object RowEncoder {
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}

@@ -220,9 +221,15 @@ object RowEncoder {
CreateExternalRow(fields, schema)
}

private def deserializerFor(input: Expression): Expression = input.dataType match {
private def deserializerFor(input: Expression): Expression = {
deserializerFor(input, input.dataType)
}
Contributor:
Seems that this method is never used?

Member Author (@viirya):
Uh? It is the original deserializerFor method, and it is used both above and below.

@liancheng (Contributor), Jun 20, 2016:
Oh, sorry... Confused by the split diff view...


private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
case dt if ScalaReflection.isNativeType(dt) => input

case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)

case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
val udtClass: Class[_] = if (annotation != null) {
@@ -262,7 +269,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
MapObjects(deserializerFor(_), input, et),
MapObjects(deserializerFor(_), input, et, Some(dataType)),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -349,11 +349,12 @@ object MapObjects {
def apply(
function: Expression => Expression,
inputData: Expression,
elementType: DataType): MapObjects = {
elementType: DataType,
inputDataType: Option[DataType] = None): MapObjects = {
I've got quite a few problems with this default:

  1. It looks like it's here to avoid coding work in most places, not because it's an obvious value for apply() to take on.
  2. The additional parameter is completely undocumented and call sites have no mention of it.

Also, the use of Option in the case class constructor is a bit obtuse.

Here is my suggestion:

  1. Remove the default.
  2. If the option passed to apply() is None, pass inputDataType.getOrElse(inputData.dataType) as the inputDataType: DataType to the case class constructor, which then uses the parameter data type without any hidden logic (as it does now).
  3. Document the fact that supplying None in apply() triggers this kind of inference.

Member Author (@viirya):
In most scenarios we don't need to specify inputDataType, so I think it is indeed the obvious value. I agree that the use of Option here is not good; I will change it and document it.
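To make the suggestion concrete, here is a standalone toy sketch of the proposed shape: no default on apply(), None documented to mean "infer from the input", and a plain DataType on the case class itself. This is not Spark code; MapObjectsLike, Expr, and the toy DataType hierarchy below are illustrative names only.

// Toy stand-ins for the real catalyst types, so the sketch compiles on its own.
sealed trait DataType
case object LongType extends DataType
final case class Expr(dataType: DataType)

// The case class carries a concrete DataType with no hidden logic.
final case class MapObjectsLike(inputData: Expr, inputDataType: DataType)

object MapObjectsLike {
  /** @param inputDataType data type of `inputData`; pass None to infer it from `inputData.dataType`. */
  def apply(inputData: Expr, inputDataType: Option[DataType]): MapObjectsLike =
    MapObjectsLike(inputData, inputDataType.getOrElse(inputData.dataType))
}

// Usage: explicit type vs. inferred from the input expression.
// MapObjectsLike(Expr(LongType), Some(LongType))
// MapObjectsLike(Expr(LongType), None)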

val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopVar, function(loopVar), inputData)
MapObjects(loopVar, function(loopVar), inputData, inputDataType)
}
}

@@ -370,11 +371,13 @@ object MapObjects {
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
* to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param inputDataType The dataType of inputData. Optional.
*/
case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {
inputData: Expression,
inputDataType: Option[DataType]) extends Expression with NonSQLExpression {

override def nullable: Boolean = true

@@ -427,8 +430,9 @@ case class MapObjects private(
case _ => ""
}

val inputDT = inputDataType.getOrElse(inputData.dataType)
@viirya viirya (Member Author), Jun 28, 2016:
@cloud-fan I think there is no easy way to catch a Python UDT before MapObjects. The approach I use now is to pass a data type (the Python UDT's sqlType) to MapObjects.


val (getLength, getLoopVar) = inputData.dataType match {
val (getLength, getLoopVar) = inputDT match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
case ObjectType(cls) if cls.isArray =>
@@ -442,7 +446,7 @@ case class MapObjects private(
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}

val loopNullCheck = inputData.dataType match {
val loopNullCheck = inputDT match {
case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>