Skip to content

Commit 9d4e621

Browse files
Davies Liu (davies)
authored and committed
[SPARK-16802] [SQL] fix overflow in LongToUnsafeRowMap
## What changes were proposed in this pull request? This patch fix the overflow in LongToUnsafeRowMap when the range of key is very wide (the key is much much smaller then minKey, for example, key is Long.MinValue, minKey is > 0). ## How was this patch tested? Added regression test (also for SPARK-16740) Author: Davies Liu <[email protected]> Closes #14464 from davies/fix_overflow.
1 parent 9d7a474 commit 9d4e621

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
459459
*/
460460
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
461461
if (isDense) {
462-
val idx = (key - minKey).toInt
463-
if (idx >= 0 && key <= maxKey && array(idx) > 0) {
464-
return getRow(array(idx), resultRow)
462+
if (key >= minKey && key <= maxKey) {
463+
val value = array((key - minKey).toInt)
464+
if (value > 0) {
465+
return getRow(value, resultRow)
466+
}
465467
}
466468
} else {
467469
var pos = firstSlot(key)
@@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
497499
*/
498500
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
499501
if (isDense) {
500-
val idx = (key - minKey).toInt
501-
if (idx >=0 && key <= maxKey && array(idx) > 0) {
502-
return valueIter(array(idx), resultRow)
502+
if (key >= minKey && key <= maxKey) {
503+
val value = array((key - minKey).toInt)
504+
if (value > 0) {
505+
return valueIter(value, resultRow)
506+
}
503507
}
504508
} else {
505509
var pos = firstSlot(key)

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
152152
}
153153
}
154154

155+
test("LongToUnsafeRowMap with very wide range") {
156+
val taskMemoryManager = new TaskMemoryManager(
157+
new StaticMemoryManager(
158+
new SparkConf().set("spark.memory.offHeap.enabled", "false"),
159+
Long.MaxValue,
160+
Long.MaxValue,
161+
1),
162+
0)
163+
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
164+
165+
{
166+
// SPARK-16740
167+
val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
168+
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
169+
keys.foreach { k =>
170+
map.append(k, unsafeProj(InternalRow(k)))
171+
}
172+
map.optimize()
173+
val row = unsafeProj(InternalRow(0L)).copy()
174+
keys.foreach { k =>
175+
assert(map.getValue(k, row) eq row)
176+
assert(row.getLong(0) === k)
177+
}
178+
map.free()
179+
}
180+
181+
182+
{
183+
// SPARK-16802
184+
val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
185+
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
186+
keys.foreach { k =>
187+
map.append(k, unsafeProj(InternalRow(k)))
188+
}
189+
map.optimize()
190+
val row = unsafeProj(InternalRow(0L)).copy()
191+
keys.foreach { k =>
192+
assert(map.getValue(k, row) eq row)
193+
assert(row.getLong(0) === k)
194+
}
195+
assert(map.getValue(Long.MinValue, row) eq null)
196+
map.free()
197+
}
198+
}
199+
155200
test("Spark-14521") {
156201
val ser = new KryoSerializer(
157202
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()

0 commit comments

Comments (0)