From d3dd60e476a7ae1b6fd3d58d12bc6bcb79bf17ee Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 25 Jun 2024 20:43:47 +0800 Subject: [PATCH 1/3] add check --- .../apache/spark/sql/catalyst/expressions/UnsafeRow.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6325ba68af5b..0fc5a64aa038 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -155,6 +155,15 @@ public UnsafeRow() {} public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8"; + if (baseObject instanceof byte[] bytes) { + int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); + if (bytes.length < offsetInByteArray + sizeInBytes) { + throw new ArrayIndexOutOfBoundsException( + "byte array length: " + bytes.length + + ", offset: " + offsetInByteArray + ", size: " + sizeInBytes + ); + } + } this.baseObject = baseObject; this.baseOffset = baseOffset; this.sizeInBytes = sizeInBytes; From b88dcba096062402fbba84d480be192aa367f8fc Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Tue, 25 Jun 2024 21:10:00 +0800 Subject: [PATCH 2/3] add ut --- .../sql/catalyst/expressions/UnsafeRow.java | 3 ++- .../org/apache/spark/sql/UnsafeRowSuite.scala | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 0fc5a64aa038..fd6feea521ae 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -157,7 +157,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8"; if (baseObject instanceof byte[] bytes) { int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); - if (bytes.length < offsetInByteArray + sizeInBytes) { + if (offsetInByteArray < 0 || sizeInBytes < 0 || + bytes.length < offsetInByteArray + sizeInBytes) { throw new ArrayIndexOutOfBoundsException( "byte array length: " + bytes.length + ", offset: " + offsetInByteArray + ", size: " + sizeInBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 9daa69ce9f15..1643dfe25f4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -188,4 +188,31 @@ class UnsafeRowSuite extends SparkFunSuite { unsafeRow.setDecimal(0, d2, 38) assert(unsafeRow.getDecimal(0, 38, 18) === null) } + + test("SPARK-48713: throw ArrayIndexOutOfBoundsException for illegal UnsafeRow.pointTo") { + val emptyRow = UnsafeRow.createFromByteArray(64, 2) + val byteArray = new Array[Byte](64) + + // Out of bounds + var errorMsg = intercept[ArrayIndexOutOfBoundsException] { + emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, 32) + }.getMessage + assert(errorMsg.contains("byte array length: 64, offset: 50, size: 32")) + + // Negative size + errorMsg = intercept[ArrayIndexOutOfBoundsException] { + emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, -32) + }.getMessage + assert(errorMsg.contains("byte array length: 64, offset: 50, size: -32")) + + // Negative offset + errorMsg = intercept[ArrayIndexOutOfBoundsException] { + emptyRow.pointTo(byteArray, -5, 32) + }.getMessage + assert( + errorMsg.contains( + s"byte array length: 64, offset: ${-5 - Platform.BYTE_ARRAY_OFFSET}, size: 32" + ) + ) + } } From 4244864f1b3f8f0bef9039dfdd468b8adeeea1a4 Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Wed, 26 Jun 2024 10:35:52 +0800 Subject: [PATCH 3/3] address comment --- .../sql/catalyst/expressions/UnsafeRow.java | 9 ++++--- .../org/apache/spark/sql/UnsafeRowSuite.scala | 25 +++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fd6feea521ae..8741c206f2bb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -21,12 +21,14 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.util.Map; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.SparkIllegalArgumentException; import org.apache.spark.SparkUnsupportedOperationException; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.types.*; @@ -159,9 +161,10 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); if (offsetInByteArray < 0 || sizeInBytes < 0 || bytes.length < offsetInByteArray + sizeInBytes) { - throw new ArrayIndexOutOfBoundsException( - "byte array length: " + bytes.length + - ", offset: " + offsetInByteArray + ", size: " + sizeInBytes + throw new SparkIllegalArgumentException( + "INTERNAL_ERROR", + Map.of("message", "Invalid byte array backed UnsafeRow: byte array length=" + + bytes.length + ", offset=" + offsetInByteArray + ", byte size=" + sizeInBytes) ); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 1643dfe25f4f..18a6c538e0a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, SparkIllegalArgumentException} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} @@ -189,29 +189,38 @@ class UnsafeRowSuite extends SparkFunSuite { assert(unsafeRow.getDecimal(0, 38, 18) === null) } - test("SPARK-48713: throw ArrayIndexOutOfBoundsException for illegal UnsafeRow.pointTo") { + test("SPARK-48713: throw SparkIllegalArgumentException for illegal UnsafeRow.pointTo") { val emptyRow = UnsafeRow.createFromByteArray(64, 2) val byteArray = new Array[Byte](64) // Out of bounds - var errorMsg = intercept[ArrayIndexOutOfBoundsException] { + var errorMsg = intercept[SparkIllegalArgumentException] { emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, 32) }.getMessage - assert(errorMsg.contains("byte array length: 64, offset: 50, size: 32")) + assert( + errorMsg.contains( + "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=32" + ) + ) // Negative size - errorMsg = intercept[ArrayIndexOutOfBoundsException] { + errorMsg = intercept[SparkIllegalArgumentException] { emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, -32) }.getMessage - assert(errorMsg.contains("byte array length: 64, offset: 50, size: -32")) + assert( + errorMsg.contains( + "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=-32" + ) + ) // Negative offset - errorMsg = intercept[ArrayIndexOutOfBoundsException] { + errorMsg = intercept[SparkIllegalArgumentException] { emptyRow.pointTo(byteArray, -5, 32) }.getMessage assert( errorMsg.contains( - s"byte array length: 64, offset: ${-5 - Platform.BYTE_ARRAY_OFFSET}, size: 32" + s"Invalid byte array backed UnsafeRow: byte array length=64, " + + s"offset=${-5 - Platform.BYTE_ARRAY_OFFSET}, byte size=32" ) ) }