Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
*/
override private[spark] def asNullable: UserDefinedType[UserType] = this

override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
override private[sql] def acceptsType(dataType: DataType) = dataType match {
case other: UserDefinedType[_] =>
this.getClass == other.getClass ||
this.userClass.isAssignableFrom(other.userClass)
case _ => false
}

override def sql: String = sqlType.sql

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -71,6 +72,77 @@ object UDT {

}

// object and classes to test SPARK-19311

// Trait/Interface for base type
sealed trait IExampleBaseType extends Serializable {
def field: Int
}

// Trait/Interface for derived type
sealed trait IExampleSubType extends IExampleBaseType

// a base class
class ExampleBaseClass(override val field: Int) extends IExampleBaseType

// a derived class
class ExampleSubClass(override val field: Int)
extends ExampleBaseClass(field) with IExampleSubType

// UDT for base class
class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {

override def sqlType: StructType = {
StructType(Seq(
StructField("intfield", IntegerType, nullable = false)))
}

override def serialize(obj: IExampleBaseType): InternalRow = {
val row = new GenericInternalRow(1)
row.setInt(0, obj.field)
row
}

override def deserialize(datum: Any): IExampleBaseType = {
datum match {
case row: InternalRow =>
require(row.numFields == 1,
"ExampleBaseTypeUDT requires row with length == 1")
val field = row.getInt(0)
new ExampleBaseClass(field)
}
}

override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
}

// UDT for derived class
private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {

override def sqlType: StructType = {
StructType(Seq(
StructField("intfield", IntegerType, nullable = false)))
}

override def serialize(obj: IExampleSubType): InternalRow = {
val row = new GenericInternalRow(1)
row.setInt(0, obj.field)
row
}

override def deserialize(datum: Any): IExampleSubType = {
datum match {
case row: InternalRow =>
require(row.numFields == 1,
"ExampleSubTypeUDT requires row with length == 1")
val field = row.getInt(0)
new ExampleSubClass(field)
}
}

override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
}

class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
import testImplicits._

Expand Down Expand Up @@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
// call `collect` to make sure this query can pass analysis.
pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
}

test("SPARK-19311: UDFs disregard UDT type hierarchy") {
UDTRegistration.register(classOf[IExampleBaseType].getName,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With SQLUserDefinedType, no need to use UDTRegistration. We can remove this two lines.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i tend to leave them, but remove the @SQLUserDefinedType, so we have a test that uses UDTRegistration

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh. if you worry about that, actually we have UDTRegistrationSuite for test case of UDTRegistration. i am fine to either SQLUserDefinedType or UDTRegistration.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

classOf[ExampleBaseTypeUDT].getName)
UDTRegistration.register(classOf[IExampleSubType].getName,
classOf[ExampleSubTypeUDT].getName)

// UDF that returns a base class object
sqlContext.udf.register("doUDF", (param: Int) => {
new ExampleBaseClass(param)
}: IExampleBaseType)

// UDF that returns a derived class object
sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
new ExampleSubClass(param)
}: IExampleSubType)

// UDF that takes a base class object as parameter
sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
obj.field
}: Int)

// this worked already before the fix SPARK-19311:
// return type of doUDF equals parameter type of doOtherUDF
sql("SELECT doOtherUDF(doUDF(41))")

// this one passes only with the fix SPARK-19311:
// return type of doSubUDF is a subtype of the parameter type of doOtherUDF
sql("SELECT doOtherUDF(doSubTypeUDF(42))")
}

}