-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-16062][SPARK-15989][SQL] Fix two bugs of Python-only UDTs #13778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
f26c8dc
cd80f0e
d22dca8
fc9c106
d603cc2
a0b81ba
4c00bb1
1583fe3
65a33b0
1b751af
87a0953
6065364
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -558,6 +558,41 @@ def check_datatype(datatype): | |
| _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) | ||
| self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) | ||
|
|
||
| def test_simple_udt_in_df(self): | ||
| schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], | ||
| schema=schema) | ||
| df.show() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test only fails when using |
||
|
|
||
| def test_nested_udt_in_df(self): | ||
| schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here, need its own named unit test (so that it's easier to identify the problem if the test fails) - unit tests should test only one thing, the thing tested here is |
||
| df = self.spark.createDataFrame( | ||
| [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| schema = StructType().add("key", LongType()).add("val", | ||
| MapType(LongType(), PythonOnlyUDT())) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| def test_complex_nested_udt_in_df(self): | ||
| from pyspark.sql.functions import udf | ||
|
|
||
| schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| gd = df.groupby("key").agg({"val": "collect_list"}) | ||
| gd.collect() | ||
| udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) | ||
| gd.select(udf(*gd)).collect() | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing a unit test for "counterexample 2", nested complex udt struct with mixed types: ArrayType(StructType(simple sql type, udt)) |
||
| def test_infer_schema_with_udt(self): | ||
| from pyspark.sql.tests import ExamplePoint, ExamplePointUDT | ||
| row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,6 +206,7 @@ object RowEncoder { | |
| case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) | ||
| case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) | ||
| case _: StructType => ObjectType(classOf[Row]) | ||
| case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) | ||
| case udt: UserDefinedType[_] => ObjectType(udt.userClass) | ||
| } | ||
|
|
||
|
|
@@ -220,9 +221,15 @@ object RowEncoder { | |
| CreateExternalRow(fields, schema) | ||
| } | ||
|
|
||
| private def deserializerFor(input: Expression): Expression = input.dataType match { | ||
| private def deserializerFor(input: Expression): Expression = { | ||
| deserializerFor(input, input.dataType) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems that this method is never used?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Uh? It is the original `deserializerFor` method and is used below and above.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, sorry... Confused by the split diff view... |
||
|
|
||
| private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { | ||
| case dt if ScalaReflection.isNativeType(dt) => input | ||
|
|
||
| case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) | ||
|
|
||
| case udt: UserDefinedType[_] => | ||
| val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) | ||
| val udtClass: Class[_] = if (annotation != null) { | ||
|
|
@@ -262,7 +269,7 @@ object RowEncoder { | |
| case ArrayType(et, nullable) => | ||
| val arrayData = | ||
| Invoke( | ||
| MapObjects(deserializerFor(_), input, et), | ||
| MapObjects(deserializerFor(_), input, et, Some(dataType)), | ||
| "array", | ||
| ObjectType(classOf[Array[_]])) | ||
| StaticInvoke( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -349,11 +349,12 @@ object MapObjects { | |
| def apply( | ||
| function: Expression => Expression, | ||
| inputData: Expression, | ||
| elementType: DataType): MapObjects = { | ||
| elementType: DataType, | ||
| inputDataType: Option[DataType] = None): MapObjects = { | ||
|
||
| val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() | ||
| val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() | ||
| val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) | ||
| MapObjects(loopVar, function(loopVar), inputData) | ||
| MapObjects(loopVar, function(loopVar), inputData, inputDataType) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -370,11 +371,13 @@ object MapObjects { | |
| * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function | ||
| * to handle collection elements. | ||
| * @param inputData An expression that when evaluated returns a collection object. | ||
| * @param inputDataType The dataType of inputData. Optional. | ||
| */ | ||
| case class MapObjects private( | ||
| loopVar: LambdaVariable, | ||
| lambdaFunction: Expression, | ||
| inputData: Expression) extends Expression with NonSQLExpression { | ||
| inputData: Expression, | ||
| inputDataType: Option[DataType]) extends Expression with NonSQLExpression { | ||
|
|
||
| override def nullable: Boolean = true | ||
|
|
||
|
|
@@ -427,8 +430,9 @@ case class MapObjects private( | |
| case _ => "" | ||
| } | ||
|
|
||
| val inputDT = inputDataType.getOrElse(inputData.dataType) | ||
|
||
|
|
||
| val (getLength, getLoopVar) = inputData.dataType match { | ||
| val (getLength, getLoopVar) = inputDT match { | ||
| case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => | ||
| s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" | ||
| case ObjectType(cls) if cls.isArray => | ||
|
|
@@ -442,7 +446,7 @@ case class MapObjects private( | |
| s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" | ||
| } | ||
|
|
||
| val loopNullCheck = inputData.dataType match { | ||
| val loopNullCheck = inputDT match { | ||
| case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" | ||
| // The element of primitive array will never be null. | ||
| case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be in its own test method - it's no longer merely a `test_udt` but rather a `test_simple_udt_in_df`.