fix overflow in LongToUnsafeRowMap
Davies Liu committed Aug 2, 2016
commit 97027f086f1e7a977832a1acfb9ebe29250a28ea
@@ -459,8 +459,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+      val idx = (key - minKey).toInt // could overflow
+      if (key >= minKey && key <= maxKey && array(idx) > 0) {
Member:
Yeah, I see where this is going, but I think this doesn't totally eliminate the problem. key - minKey could still overflow such that the resulting int is positive and passes the idx >= 0 check. It seems like we need to test the keys against each other as Longs, and only then convert to an int to index into the array?
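
For concreteness, a minimal sketch of the wraparound being described, using the keys from the SPARK-16802 regression test added below (illustration only, not part of the patch):

val minKey = Long.MaxValue - 10   // smallest key in the dense map
val maxKey = Long.MaxValue        // largest key in the dense map
val key = Long.MinValue           // probe key far below the range

val idx = (key - minKey).toInt    // the Long subtraction wraps around to 11L, so idx == 11
// Old guard: idx >= 0 passes and key <= maxKey passes,
// so array(11) would be read for a key that was never inserted.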

@davies (Contributor, Author) commented Aug 2, 2016:

I think having both key >= minKey and key <= maxKey makes sure that there is no overflow (we already make sure that the range between minKey and maxKey is smaller than Int.MaxValue), so we can safely use (key - minKey).toInt.
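
Spelled out (a sketch of the author's argument, assuming as stated that dense mode is only chosen when the range maxKey - minKey fits in an Int):

// Given key >= minKey and key <= maxKey:
//   0 <= key - minKey <= maxKey - minKey < Int.MaxValue
// so the Long subtraction cannot wrap and the narrowing to Int is lossless.
val idx = (key - minKey).toInt // safe once both bounds checks have passed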

Member:
Ah yeah, OK, this should be OK. I might suggest the following as a little simpler, but whatever:

if (key >= minKey && key <= maxKey) {
  val value = array((key - minKey).toInt)
  if (value > 0) {
    return getRow(value, resultRow)
  }
}

?

         return getRow(array(idx), resultRow)
       }
     } else {
@@ -497,8 +497,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >=0 && key <= maxKey && array(idx) > 0) {
+      val idx = (key - minKey).toInt // could overflow
+      if (key >= minKey && key <= maxKey && array(idx) > 0) {
         return valueIter(array(idx), resultRow)
       }
     } else {
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     }
   }

test("LongToUnsafeRowMap with very wide range") {
val taskMemoryManager = new TaskMemoryManager(
new StaticMemoryManager(
new SparkConf().set("spark.memory.offHeap.enabled", "false"),
Long.MaxValue,
Long.MaxValue,
1),
0)
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))

{
// SPARK-16740
val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
map.optimize()
val row = unsafeProj(InternalRow(0L)).copy()
keys.foreach { k =>
assert(map.getValue(k, row) eq row)
assert(row.getLong(0) === k)
}
map.free()
}


{
// SPARK-16802
val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
map.optimize()
val row = unsafeProj(InternalRow(0L)).copy()
keys.foreach { k =>
assert(map.getValue(k, row) eq row)
assert(row.getLong(0) === k)
}
assert(map.getValue(Long.MinValue, row) eq null)
map.free()
}
}

test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()