
Commit 8883401

cxzl25 authored and cloud-fan committed

1 parent 230f144 · commit 8883401
[SPARK-24257][SQL] LongToUnsafeRowMap may calculate the new size incorrectly
LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2` even when the new record is larger than `oldSize * 2`. Querying the map can then yield a malformed UnsafeRow whose actual data is smaller than its declared size, i.e. the data is corrupted.

Author: sychen <[email protected]>

Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size.
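To see the failure in isolation: if the page holds 1 << 17 longs (1M bytes) and a record needs about 4M bytes, doubling yields only 2M bytes, so the record still does not fit. Below is a minimal sketch of the two growth policies (illustrative names only, not the Spark code itself):

    object GrowPolicySketch {
      // Old policy: always double the page, no matter how big the record is.
      def oldNewSizeWords(pageWords: Int): Int = pageWords * 2

      // Fixed policy: grow to at least what the record needs; double when
      // doubling is enough, capped at 1 << 30 words (8G bytes).
      def fixedNewSizeWords(pageWords: Int, neededWords: Int): Int =
        math.max(neededWords, math.min(pageWords * 2, 1 << 30))

      def main(args: Array[String]): Unit = {
        val pageWords = 1 << 17              // a 1M-byte page of longs
        val neededWords = (1 << 22) / 8 + 2  // a ~4M-byte record plus pointer word
        println(oldNewSizeWords(pageWords) >= neededWords)                // false: still too small
        println(fixedNewSizeWords(pageWords, neededWords) >= neededWords) // true
      }
    }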

File tree: 2 files changed (+48, -16 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 23 additions & 15 deletions
@@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   def append(key: Long, row: UnsafeRow): Unit = {
     val sizeInBytes = row.getSizeInBytes
     if (sizeInBytes >= (1 << SIZE_BITS)) {
-      sys.error("Does not support row that is larger than 256M")
+      throw new UnsupportedOperationException("Does not support row that is larger than 256M")
     }
 
     if (key < minKey) {
@@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       maxKey = key
     }
 
-    // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
-      val used = page.length
-      if (used >= (1 << 30)) {
-        sys.error("Can not build a HashedRelation that is larger than 8G")
-      }
-      ensureAcquireMemory(used * 8L * 2)
-      val newPage = new Array[Long](used * 2)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      page = newPage
-      freeMemory(used * 8L)
-    }
+    grow(row.getSizeInBytes)
 
     // copy the bytes of UnsafeRow
     val offset = cursor
@@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
           growArray()
         } else if (numKeys > array.length / 2 * 0.75) {
           // The fill ratio should be less than 0.75
-          sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+          throw new UnsupportedOperationException(
+            "Cannot build HashedRelation with more than 1/3 billions unique keys")
         }
       }
     } else {
@@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }
 
+  private def grow(inputRowSize: Int): Unit = {
+    // There is 8 bytes for the pointer to next value
+    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.length) {
+      if (neededNumWords > (1 << 30)) {
+        throw new UnsupportedOperationException(
+          "Can not build a HashedRelation that is larger than 8G")
+      }
+      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+      ensureAcquireMemory(newNumWords * 8L)
+      val newPage = new Array[Long](newNumWords.toInt)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
+      val used = page.length
+      page = newPage
+      freeMemory(used * 8L)
+    }
+  }
+
   private def growArray(): Unit = {
     var old_array = array
     val n = array.length
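A note on the arithmetic in the new grow(): `(cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8` converts a byte requirement (bytes already written, plus the 8-byte pointer to the next value, plus the incoming row) into whole 8-byte words, rounding up via the `+ 7) / 8` ceiling-division idiom. A quick sanity check of that arithmetic, with illustrative values:

    // Ceiling division of a byte requirement into 8-byte words, as in grow().
    def wordsNeeded(usedBytes: Long, rowBytes: Long): Long =
      (usedBytes + 8 + rowBytes + 7) / 8

    assert(wordsNeeded(0, 8) == 2)    // row plus next-value pointer fill exactly 2 words
    assert(wordsNeeded(0, 9) == 3)    // one extra byte rounds up to a third word
    assert(wordsNeeded(16, 24) == 6)  // (16 + 8 + 24) / 8 = 6, no rounding needed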

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 25 additions & 1 deletion
@@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
@@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     map.free()
   }
 
+  test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+    val key = 0L
+    // the page array is initialized with length 1 << 17 (1M bytes),
+    // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
+    val bigStr = UTF8String.fromString("x" * (1 << 22))
+
+    map.append(key, unsafeProj(InternalRow(bigStr)))
+    map.optimize()
+
+    val resultRow = new UnsafeRow(1)
+    assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
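Before the fix, appending the 4M-byte value would only double the initial 1M-byte page to 2M bytes, so the copy ran past the end of the new array and the value read back did not match. To run just this regression test from a Spark checkout, the usual sbt/ScalaTest invocation should work (module name and `-z` substring filter assumed):

    build/sbt "sql/testOnly *HashedRelationSuite -- -z SPARK-24257"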
