Skip to content

Commit d3f87dc

Browse files
JoshRosen authored and marmbrus committed
[SPARK-10325] Override hashCode() for public Row
This commit fixes an issue where the public SQL `Row` class did not override `hashCode`, causing it to violate the hashCode() + equals() contract. To fix this, I simply ported the `hashCode` implementation from the 1.4.x version of `Row`. Author: Josh Rosen <joshrosen@databricks.com> Closes #8500 from JoshRosen/SPARK-10325 and squashes the following commits: 51ffea1 [Josh Rosen] Override hashCode() for public Row.
1 parent 499e8e1 commit d3f87dc

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import scala.collection.JavaConverters._
21+
import scala.util.hashing.MurmurHash3
2122

2223
import org.apache.spark.sql.catalyst.InternalRow
2324
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -410,6 +411,18 @@ trait Row extends Serializable {
410411
true
411412
}
412413

414+
/**
 * Hash code derived from the row's fields, visited in positional order.
 *
 * Starts from `MurmurHash3.seqSeed`, mixes in the `##` hash of each field, and
 * finalizes with the number of fields mixed, so rows holding equal values in the
 * same positions produce the same hash code.
 *
 * @return the hash of this row's field values
 */
override def hashCode: Int = {
415+
// Using Scala's Seq hash code implementation.
416+
// n: index of the next field to mix in; after the loop it equals the field count.
var n = 0
417+
// h: running hash state, seeded with the sequence seed.
var h = MurmurHash3.seqSeed
418+
// Read the field count once rather than on every loop iteration.
val len = length
419+
while (n < len) {
420+
// `##` rather than `hashCode` -- presumably so null fields do not throw and boxed
// numerics hash consistently with `==`; confirm against the `equals` implementation.
h = MurmurHash3.mix(h, apply(n).##)
421+
n += 1
422+
}
423+
// Fold the number of mixed fields into the final hash.
MurmurHash3.finalizeHash(h, n)
424+
}
425+
413426
/* ---------------------- utility methods for Scala ---------------------- */
414427

415428
/**

sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,13 @@ class RowSuite extends SparkFunSuite with SharedSQLContext {
8585
val r2 = Row(Double.NaN)
8686
assert(r1 === r2)
8787
}
88+
89+
// Checks that Row's hashCode agrees with equals: two rows built from the same value
// compare equal and report the same hash code.
test("equals and hashCode") {
90+
val r1 = Row("Hello")
91+
val r2 = Row("Hello")
92+
assert(r1 === r2)
93+
// Equal rows must hash identically.
assert(r1.hashCode() === r2.hashCode())
94+
val r3 = Row("World")
95+
// Sanity check for this specific pair: differing contents should yield a different hash.
// (Distinct hashes are not guaranteed for unequal rows in general.)
assert(r3.hashCode() != r1.hashCode())
96+
}
8897
}

0 commit comments

Comments
 (0)