Commit 0ec17c9

Ngone51 authored and dongjoon-hyun committed
[SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical
### What changes were proposed in this pull request?

This PR fixes `UserDefinedType.equals()` by comparing the UDT classes instead of checking `acceptsType()`.

### Why are the changes needed?

An equality comparison between two UDT types can give a different result when the operands are swapped:

```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```

### Does this PR introduce _any_ user-facing change?

Yes.

Before:

```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```

After:

```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // false
println(udt2 == udt1) // false
```

### How was this patch tested?

Added a unit test.

Closes apache#28923 from Ngone51/fix-udt-equal.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent f944603 commit 0ec17c9

2 files changed: +19 -1 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 1 addition & 1 deletion

@@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
   override def hashCode(): Int = getClass.hashCode()

   override def equals(other: Any): Boolean = other match {
-    case that: UserDefinedType[_] => this.acceptsType(that)
+    case that: UserDefinedType[_] => this.getClass == that.getClass
     case _ => false
   }

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 18 additions & 0 deletions

@@ -134,6 +134,24 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
       MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
       MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF()

+
+  test("SPARK-32090: equal") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    val udt3 = new ExampleSubTypeUDT
+    assert(udt1 !== udt2)
+    assert(udt2 !== udt1)
+    assert(udt2 === udt3)
+    assert(udt3 === udt2)
+  }
+
+  test("SPARK-32090: acceptsType") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    assert(udt1.acceptsType(udt2))
+    assert(!udt2.acceptsType(udt1))
+  }
+
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
     val labelsArrays: Array[Double] = labels.collect()
