From 7ea9aa636f430a30b8d83ed2dda954fd06347d79 Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Fri, 20 Jan 2017 16:16:53 +0100
Subject: [PATCH 1/7] fix UDT hierarchy issue

https://issues.apache.org/jira/browse/SPARK-19311
---
 .../org/apache/spark/sql/types/UserDefinedType.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index c33219c95b50..5a944e763e09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
    */
   override private[spark] def asNullable: UserDefinedType[UserType] = this
 
-  override private[sql] def acceptsType(dataType: DataType) =
-    this.getClass == dataType.getClass
+  override private[sql] def acceptsType(dataType: DataType) = dataType match {
+    case other: UserDefinedType[_] =>
+      this.getClass == other.getClass ||
+        this.userClass.isAssignableFrom(other.userClass)
+    case _ => false
+  }
 
   override def sql: String = sqlType.sql
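In effect, the patched acceptsType lets a UDT over a base user class accept a UDT over a derived user class, but not the reverse. A minimal sketch of the rule, using two hypothetical traits standing in for user classes (they are not part of the patch):

    // Hypothetical user-class hierarchy.
    trait Base extends Serializable
    trait Sub extends Base

    // When the UDT classes themselves differ, acceptsType now falls back to a
    // Java-reflection subtype check between the two UDTs' userClass values:
    classOf[Base].isAssignableFrom(classOf[Sub])  // true  -> base-type UDT accepts sub-type UDT
    classOf[Sub].isAssignableFrom(classOf[Base])  // false -> sub-type UDT rejects base-type UDT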
From fad9f0e1bc98ad5e7ee2c8dfc726ddc979204f11 Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Mon, 23 Jan 2017 13:51:09 +0100
Subject: [PATCH 2/7] add test case for SPARK-19311. Test case still fails -
 cause from elsewhere?

Test case failure is:

- SPARK-19311: UDFs disregard UDT type hierarchy *** FAILED ***
  org.apache.spark.sql.catalyst.errors.package$TreeNodeException:
  Max iterations (100) reached for batch Resolution, tree:
  Project [UDF(cast(cast(cast(...(UDF(41) as exampleBaseType) as exampleBaseType)...) as exampleBaseType)) AS UDF(UDF(41))#166]
  +- SubqueryAlias tmp_table
     +- Project [_1#157 AS id#160, _2#158 AS saying#161]
        +- LocalRelation [_1#157, _2#158]

  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:105)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
  at scala.collection.immutable.List.foreach(List.scala:381)
  at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
  at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:64)
  at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:62)
  at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:48)
  at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:64)
  at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:592)
  at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:61)
---
 .../spark/sql/UserDefinedTypeSuite.scala      | 123 +++++++++++++++++-
 1 file changed, 121 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 474f17ff7afb..2821ddb37124 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -18,9 +18,9 @@ package org.apache.spark.sql
 
 import scala.beans.{BeanInfo, BeanProperty}
-
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
@@ -71,6 +71,94 @@ object UDT {
 
 }
 
+// object and classes to test SPARK-19311
+
+  // Trait/Interface for base type
+  @SQLUserDefinedType(udt = classOf[ExampleBaseTypeUDT])
+  sealed trait IExampleBaseType extends Serializable {
+    def field: Int
+  }
+
+  // Trait/Interface for derived type
+  @SQLUserDefinedType(udt = classOf[ExampleSubTypeUDT])
+  sealed trait IExampleSubType extends IExampleBaseType
+
+  // a base class
+  class ExampleBaseClass(override val field: Int) extends IExampleBaseType {
+    override def toString: String = field.toString
+
+  }
+
+  // a derived class
+  class ExampleSubClass(override val field: Int)
+    extends ExampleBaseClass(field) with IExampleSubType
+
+  // UDT for base class
+  private[spark] class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+
+    override def sqlType: StructType = {
+      StructType(Seq(
+        StructField("intfield", IntegerType, nullable = false)))
+    }
+
+    override def serialize(obj: IExampleBaseType): InternalRow = {
+      val row = new GenericInternalRow(1)
+      row.setInt(0, obj.field)
+      row
+    }
+
+    override def deserialize(datum: Any): IExampleBaseType = {
+      datum match {
+        case row: InternalRow =>
+          require(row.numFields == 1,
+            s"VectorUDT.deserialize given row with length " +
+              s"${row.numFields} but requires length == 1")
+          val field = row.getInt(0)
+          new ExampleBaseClass(field)
+      }
+    }
+
+    override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
+    override def hashCode(): Int = classOf[ExampleBaseTypeUDT].getName.hashCode()
+    override def equals(other: Any): Boolean = other.isInstanceOf[IExampleBaseType]
+    override def typeName: String = "exampleBaseType"
+    private[spark] override def asNullable: ExampleBaseTypeUDT = this
+  }
+
+  // UDT for derived class
+  private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+
+    override def sqlType: StructType = {
+      StructType(Seq(
+        StructField("intfield", IntegerType, nullable = false)))
+    }
+
+    override def serialize(obj: IExampleSubType): InternalRow = {
+
+      val row = new GenericInternalRow(1)
+      row.setInt(0, obj.field)
+      row
+    }
+
+    override def deserialize(datum: Any): IExampleSubType = {
+      datum match {
+        case row: InternalRow =>
+          require(row.numFields == 1,
+            s"VectorUDT.deserialize given row with length " +
+              s"${row.numFields} but requires length == 1")
+          val field = row.getInt(0)
+          new ExampleSubClass(field)
+      }
+    }
+
+    override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
+    override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode()
+    override def equals(other: Any): Boolean = other.isInstanceOf[IExampleSubType]
+    override def typeName: String = "exampleFirstSubType"
+    private[spark] override def asNullable: ExampleSubTypeUDT = this
+  }
+
+
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
   import testImplicits._
 
@@ -194,4 +282,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     // call `collect` to make sure this query can pass analysis.
     pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
   }
+
+  test("SPARK-19311: UDFs disregard UDT type hierarchy") {
+    UDTRegistration.register(classOf[IExampleBaseType].getName,
+      classOf[ExampleBaseTypeUDT].getName)
+    UDTRegistration.register(classOf[IExampleSubType].getName,
+      classOf[ExampleSubTypeUDT].getName)
+
+    // UDF that returns a base class object
+    sqlContext.udf.register("doUDF", (param: Int) => {
+      new ExampleBaseClass(param)
+    }: IExampleBaseType)
+
+    // UDF that returns a derived class object
+    sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
+      new ExampleSubClass(param)
+    }: IExampleSubType)
+
+    // UDF that takes a base class object as parameter
+    sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
+      obj.field
+    }: Int)
+
+    // this worked already before the fix SPARK-19311:
+    // return type of doFirstUDF equals parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doUDF(41))")
+
+    // this one passes only with the fix SPARK-19311:
+    // return type of doFirstSubUDF is a subtype of the parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(ddSubTypeUDF(42))")
+  }
+
 }
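A note on the failure quoted above: the analyzer loop appears to come from the equals methods of the two UDTs as this patch defines them. They compare a UDT against values of the user class rather than against other UDT instances, so two instances of the same UDT never compare equal and each resolution pass wraps the expression in yet another cast. A quick illustration against the classes exactly as defined in this patch:

    val a = new ExampleBaseTypeUDT
    val b = new ExampleBaseTypeUDT
    a == b                        // false: b is a UDT instance, not an IExampleBaseType value
    a == new ExampleBaseClass(1)  // true: a data type "equals" a data value

Patch 3 below corrects both equals methods, along with the naming typos (ddSubTypeUDF, exampleFirstSubType).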
From fb261d7e329b0da5970fd7f6f7eaa2ee5064d89b Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Mon, 23 Jan 2017 15:16:02 +0100
Subject: [PATCH 3/7] minor corrections of unit test case

---
 .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 2821ddb37124..a94702829113 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -120,7 +120,7 @@ object UDT {
 
     override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
     override def hashCode(): Int = classOf[ExampleBaseTypeUDT].getName.hashCode()
-    override def equals(other: Any): Boolean = other.isInstanceOf[IExampleBaseType]
+    override def equals(other: Any): Boolean = other.isInstanceOf[ExampleBaseTypeUDT]
     override def typeName: String = "exampleBaseType"
     private[spark] override def asNullable: ExampleBaseTypeUDT = this
   }
@@ -153,8 +153,8 @@ object UDT {
 
     override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
     override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode()
-    override def equals(other: Any): Boolean = other.isInstanceOf[IExampleSubType]
-    override def typeName: String = "exampleFirstSubType"
+    override def equals(other: Any): Boolean = other.isInstanceOf[ExampleSubTypeUDT]
+    override def typeName: String = "exampleSubType"
     private[spark] override def asNullable: ExampleSubTypeUDT = this
   }
 
@@ -305,12 +305,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     }: Int)
 
     // this worked already before the fix SPARK-19311:
-    // return type of doFirstUDF equals parameter type of doOtherUDF
+    // return type of doUDF equals parameter type of doOtherUDF
     sql("SELECT doOtherUDF(doUDF(41))")
 
     // this one passes only with the fix SPARK-19311:
-    // return type of doFirstSubUDF is a subtype of the parameter type of doOtherUDF
-    sql("SELECT doOtherUDF(ddSubTypeUDF(42))")
+    // return type of doSubUDF is a subtype of the parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doSubTypeUDF(42))")
   }
 
 }
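With equals, hashCode, and the patched acceptsType now consistent, both queries in the test resolve as intended. A REPL-style sketch of the checks involved, run against the corrected test classes (hypothetical session; it must live in the org.apache.spark.sql package because acceptsType is private[sql]):

    val baseUDT = new ExampleBaseTypeUDT
    val subUDT  = new ExampleSubTypeUDT

    baseUDT == new ExampleBaseTypeUDT  // true: equality is now per UDT class
    baseUDT == subUDT                  // false: distinct UDT classes
    baseUDT.acceptsType(subUDT)        // true: doOtherUDF's input type accepts doSubTypeUDF's result
    subUDT.acceptsType(baseUDT)        // false: a base value is not a sub value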
From 6b6b773a4df1ee017f00b9ca43a76a429ad418de Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Mon, 23 Jan 2017 15:27:45 +0100
Subject: [PATCH 4/7] re-organize imports to satisfy Scalastyle checks

---
 .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a94702829113..d5c84ff80d01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql
 
-import scala.beans.{BeanInfo, BeanProperty}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
+import scala.beans.{BeanInfo, BeanProperty}
+
 @BeanInfo
 private[sql] case class MyLabeledPoint(
   @BeanProperty label: Double,
From 4ae05559d9a1b343e61e87ad0e4841c497c13382 Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Mon, 23 Jan 2017 15:49:30 +0100
Subject: [PATCH 5/7] re-organize imports to satisfy Scalastyle checks & add
 correct indentation & lf

---
 .../spark/sql/UserDefinedTypeSuite.scala      | 146 ++++++++++--------
 1 file changed, 78 insertions(+), 68 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index d5c84ff80d01..428f87a0bbd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql
 
+import scala.beans.{BeanInfo, BeanProperty}
+
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-import scala.beans.{BeanInfo, BeanProperty}
-
 
 @BeanInfo
 private[sql] case class MyLabeledPoint(
@@ -74,93 +75,102 @@ object UDT {
 
 // object and classes to test SPARK-19311
 
-  // Trait/Interface for base type
-  @SQLUserDefinedType(udt = classOf[ExampleBaseTypeUDT])
-  sealed trait IExampleBaseType extends Serializable {
-    def field: Int
-  }
+// Trait/Interface for base type
+@SQLUserDefinedType(udt = classOf[ExampleBaseTypeUDT])
+sealed trait IExampleBaseType extends Serializable {
+  def field: Int
+}
 
-  // Trait/Interface for derived type
-  @SQLUserDefinedType(udt = classOf[ExampleSubTypeUDT])
-  sealed trait IExampleSubType extends IExampleBaseType
+// Trait/Interface for derived type
+@SQLUserDefinedType(udt = classOf[ExampleSubTypeUDT])
+sealed trait IExampleSubType extends IExampleBaseType
 
-  // a base class
-  class ExampleBaseClass(override val field: Int) extends IExampleBaseType {
-    override def toString: String = field.toString
+// a base class
+class ExampleBaseClass(override val field: Int) extends IExampleBaseType {
+  override def toString: String = field.toString
 
-  }
+}
 
-  // a derived class
-  class ExampleSubClass(override val field: Int)
-    extends ExampleBaseClass(field) with IExampleSubType
+// a derived class
+class ExampleSubClass(override val field: Int)
+  extends ExampleBaseClass(field) with IExampleSubType
 
-  // UDT for base class
-  private[spark] class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+// UDT for base class
+private[spark] class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
 
-    override def sqlType: StructType = {
-      StructType(Seq(
-        StructField("intfield", IntegerType, nullable = false)))
-    }
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
 
-    override def serialize(obj: IExampleBaseType): InternalRow = {
-      val row = new GenericInternalRow(1)
-      row.setInt(0, obj.field)
-      row
-    }
+  override def serialize(obj: IExampleBaseType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
 
-    override def deserialize(datum: Any): IExampleBaseType = {
-      datum match {
-        case row: InternalRow =>
-          require(row.numFields == 1,
-            s"VectorUDT.deserialize given row with length " +
-              s"${row.numFields} but requires length == 1")
-          val field = row.getInt(0)
-          new ExampleBaseClass(field)
-      }
+  override def deserialize(datum: Any): IExampleBaseType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          s"VectorUDT.deserialize given row with length " +
+            s"${row.numFields} but requires length == 1")
+        val field = row.getInt(0)
+        new ExampleBaseClass(field)
     }
+  }
+
+  override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
 
-    override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
-    override def hashCode(): Int = classOf[ExampleBaseTypeUDT].getName.hashCode()
-    override def equals(other: Any): Boolean = other.isInstanceOf[ExampleBaseTypeUDT]
-    override def typeName: String = "exampleBaseType"
-    private[spark] override def asNullable: ExampleBaseTypeUDT = this
+  override def hashCode(): Int = classOf[ExampleBaseTypeUDT].getName.hashCode()
+
+  override def equals(other: Any): Boolean = other.isInstanceOf[ExampleBaseTypeUDT]
+
+  override def typeName: String = "exampleBaseType"
+
+  private[spark] override def asNullable: ExampleBaseTypeUDT = this
+}
+
+// UDT for derived class
+private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
   }
 
-  // UDT for derived class
-  private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+  override def serialize(obj: IExampleSubType): InternalRow = {
+
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
 
-    override def sqlType: StructType = {
-      StructType(Seq(
-        StructField("intfield", IntegerType, nullable = false)))
+  override def deserialize(datum: Any): IExampleSubType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          s"VectorUDT.deserialize given row with length " +
+            s"${row.numFields} but requires length == 1")
+        val field = row.getInt(0)
+        new ExampleSubClass(field)
     }
+  }
 
-    override def serialize(obj: IExampleSubType): InternalRow = {
+  override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
 
-      val row = new GenericInternalRow(1)
-      row.setInt(0, obj.field)
-      row
-    }
+  override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode()
 
-    override def deserialize(datum: Any): IExampleSubType = {
-      datum match {
-        case row: InternalRow =>
-          require(row.numFields == 1,
-            s"VectorUDT.deserialize given row with length " +
-              s"${row.numFields} but requires length == 1")
-          val field = row.getInt(0)
-          new ExampleSubClass(field)
-      }
-    }
+  override def equals(other: Any): Boolean = other.isInstanceOf[ExampleSubTypeUDT]
 
-    override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
-    override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode()
-    override def equals(other: Any): Boolean = other.isInstanceOf[ExampleSubTypeUDT]
-    override def typeName: String = "exampleSubType"
-    private[spark] override def asNullable: ExampleSubTypeUDT = this
-  }
+  override def typeName: String = "exampleSubType"
+
+  private[spark] override def asNullable: ExampleSubTypeUDT = this
+}
 
 
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
+
   import testImplicits._
 
   private lazy val pointsRDD = Seq(
From 7aed9a4fada263785ce1d81acb31073ef7a401fd Mon Sep 17 00:00:00 2001
From: gmoehler
Date: Tue, 24 Jan 2017 11:11:50 +0100
Subject: [PATCH 6/7] worked in comments regarding style & not required
 overrides

---
 .../spark/sql/UserDefinedTypeSuite.scala      | 31 ++-----------------
 1 file changed, 3 insertions(+), 28 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 428f87a0bbd0..9d8c743343fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-
 @BeanInfo
 private[sql] case class MyLabeledPoint(
   @BeanProperty label: Double,
@@ -76,19 +75,16 @@ object UDT {
 // object and classes to test SPARK-19311
 
 // Trait/Interface for base type
-@SQLUserDefinedType(udt = classOf[ExampleBaseTypeUDT])
 sealed trait IExampleBaseType extends Serializable {
   def field: Int
 }
 
 // Trait/Interface for derived type
-@SQLUserDefinedType(udt = classOf[ExampleSubTypeUDT])
 sealed trait IExampleSubType extends IExampleBaseType
 
 // a base class
 class ExampleBaseClass(override val field: Int) extends IExampleBaseType {
   override def toString: String = field.toString
-
 }
 
 // a derived class
@@ -96,7 +92,7 @@ class ExampleSubClass(override val field: Int)
   extends ExampleBaseClass(field) with IExampleSubType
 
 // UDT for base class
-private[spark] class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
 
   override def sqlType: StructType = {
     StructType(Seq(
@@ -113,22 +109,13 @@ private[spark] class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType
     datum match {
       case row: InternalRow =>
         require(row.numFields == 1,
-          s"VectorUDT.deserialize given row with length " +
-            s"${row.numFields} but requires length == 1")
+          "ExampleBaseTypeUDT requires row with length == 1")
         val field = row.getInt(0)
         new ExampleBaseClass(field)
     }
   }
 
   override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
-
-  override def hashCode(): Int = classOf[ExampleBaseTypeUDT].getName.hashCode()
-
-  override def equals(other: Any): Boolean = other.isInstanceOf[ExampleBaseTypeUDT]
-
-  override def typeName: String = "exampleBaseType"
-
-  private[spark] override def asNullable: ExampleBaseTypeUDT = this
 }
 
 // UDT for derived class
@@ -140,7 +127,6 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
   }
 
   override def serialize(obj: IExampleSubType): InternalRow = {
-
     val row = new GenericInternalRow(1)
     row.setInt(0, obj.field)
     row
@@ -150,27 +136,16 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
     datum match {
       case row: InternalRow =>
         require(row.numFields == 1,
-          s"VectorUDT.deserialize given row with length " +
-            s"${row.numFields} but requires length == 1")
+          "ExampleSubTypeUDT requires row with length == 1")
         val field = row.getInt(0)
         new ExampleSubClass(field)
     }
   }
 
   override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
-
-  override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode()
-
-  override def equals(other: Any): Boolean = other.isInstanceOf[ExampleSubTypeUDT]
-
-  override def typeName: String = "exampleSubType"
-
-  private[spark] override def asNullable: ExampleSubTypeUDT = this
 }
 
-
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
-
   import testImplicits._
 
   private lazy val pointsRDD = Seq(
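Because patch 6 removes the @SQLUserDefinedType annotations, the trait-to-UDT mapping now exists only through the UDTRegistration calls made at the start of the test. A sketch of that runtime lookup; the exists/getUDTFor signatures are my assumption about Spark 2.1's private[spark] UDTRegistration object, so treat them as unverified:

    UDTRegistration.register(classOf[IExampleSubType].getName, classOf[ExampleSubTypeUDT].getName)
    UDTRegistration.exists(classOf[IExampleSubType].getName)    // true once registered
    UDTRegistration.getUDTFor(classOf[IExampleSubType].getName) // Some(classOf[ExampleSubTypeUDT])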
"exampleBaseType" - - private[spark] override def asNullable: ExampleBaseTypeUDT = this } // UDT for derived class @@ -140,7 +127,6 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] } override def serialize(obj: IExampleSubType): InternalRow = { - val row = new GenericInternalRow(1) row.setInt(0, obj.field) row @@ -150,27 +136,16 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] datum match { case row: InternalRow => require(row.numFields == 1, - s"VectorUDT.deserialize given row with length " + - s"${row.numFields} but requires length == 1") + "ExampleSubTypeUDT requires row with length == 1") val field = row.getInt(0) new ExampleSubClass(field) } } override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] - - override def hashCode(): Int = classOf[ExampleSubTypeUDT].getName.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[ExampleSubTypeUDT] - - override def typeName: String = "exampleSubType" - - private[spark] override def asNullable: ExampleSubTypeUDT = this } - class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { - import testImplicits._ private lazy val pointsRDD = Seq( From 6c16760f0b2e5aba864366c03888666ba584d38a Mon Sep 17 00:00:00 2001 From: gmoehler Date: Wed, 25 Jan 2017 10:20:41 +0100 Subject: [PATCH 7/7] remove unnecessary toString() method --- .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9d8c743343fd..ea4a8ee7ff72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -83,9 +83,7 @@ sealed trait IExampleBaseType extends Serializable { sealed trait IExampleSubType extends IExampleBaseType // a base class -class ExampleBaseClass(override val field: Int) extends IExampleBaseType { - override def toString: String = field.toString -} +class ExampleBaseClass(override val field: Int) extends IExampleBaseType // a derived class class ExampleSubClass(override val field: Int)