diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6325ba68af5b..8741c206f2bb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -21,12 +21,14 @@
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
+import java.util.Map;
 
 import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.KryoSerializable;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
 
+import org.apache.spark.SparkIllegalArgumentException;
 import org.apache.spark.SparkUnsupportedOperationException;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.types.*;
@@ -155,6 +157,19 @@ public UnsafeRow() {}
   public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
     assert numFields >= 0 : "numFields (" + numFields + ") should >= 0";
     assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8";
+    if (baseObject instanceof byte[] bytes) {
+      // Keep the offset (and therefore the end position) in 64-bit arithmetic: narrowing the
+      // offset to int, or adding two ints, can wrap around and accept an out-of-range region.
+      long offsetInByteArray = baseOffset - Platform.BYTE_ARRAY_OFFSET;
+      if (offsetInByteArray < 0 || sizeInBytes < 0 ||
+          bytes.length < offsetInByteArray + sizeInBytes) {
+        throw new SparkIllegalArgumentException(
+          "INTERNAL_ERROR",
+          Map.of("message", "Invalid byte array backed UnsafeRow: byte array length=" +
+            bytes.length + ", offset=" + offsetInByteArray + ", byte size=" + sizeInBytes)
+        );
+      }
+    }
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.sizeInBytes = sizeInBytes;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 9daa69ce9f15..18a6c538e0a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import java.io.ByteArrayOutputStream
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -188,4 +188,51 @@ class UnsafeRowSuite extends SparkFunSuite {
     unsafeRow.setDecimal(0, d2, 38)
     assert(unsafeRow.getDecimal(0, 38, 18) === null)
   }
+
+  test("SPARK-48713: throw SparkIllegalArgumentException for illegal UnsafeRow.pointTo") {
+    val emptyRow = UnsafeRow.createFromByteArray(64, 2)
+    val byteArray = new Array[Byte](64)
+
+    // Out of bounds
+    var errorMsg = intercept[SparkIllegalArgumentException] {
+      emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, 32)
+    }.getMessage
+    assert(
+      errorMsg.contains(
+        "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=32"
+      )
+    )
+
+    // Out of bounds where offset + size does not fit in an Int
+    errorMsg = intercept[SparkIllegalArgumentException] {
+      emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, Int.MaxValue - 7)
+    }.getMessage
+    assert(
+      errorMsg.contains(
+        s"Invalid byte array backed UnsafeRow: byte array length=64, offset=50, " +
+          s"byte size=${Int.MaxValue - 7}"
+      )
+    )
+
+    // Negative size
+    errorMsg = intercept[SparkIllegalArgumentException] {
+      emptyRow.pointTo(byteArray, Platform.BYTE_ARRAY_OFFSET + 50, -32)
+    }.getMessage
+    assert(
+      errorMsg.contains(
+        "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=-32"
+      )
+    )
+
+    // Negative offset
+    errorMsg = intercept[SparkIllegalArgumentException] {
+      emptyRow.pointTo(byteArray, -5, 32)
+    }.getMessage
+    assert(
+      errorMsg.contains(
+        s"Invalid byte array backed UnsafeRow: byte array length=64, " +
+          s"offset=${-5 - Platform.BYTE_ARRAY_OFFSET}, byte size=32"
+      )
+    )
+  }
 }