diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 537ef244b7e81..6a52a5b0e0664 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -35,6 +35,7 @@ final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + // buffer is guarantee to be word-aligned since UnsafeRow assumes each field is word-aligned. private byte[] buffer; private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; @@ -52,7 +53,8 @@ final class BufferHolder { "too many fields (number of fields: " + row.numFields() + ")"); } this.fixedSize = bitsetWidthInBytes + 8 * row.numFields(); - this.buffer = new byte[fixedSize + initialSize]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(fixedSize + initialSize); + this.buffer = new byte[roundedSize]; this.row = row; this.row.pointTo(buffer, buffer.length); } @@ -61,8 +63,12 @@ final class BufferHolder { * Grows the buffer by at least neededSize and points the row to the buffer. */ void grow(int neededSize) { + if (neededSize < 0) { + throw new IllegalArgumentException( + "Cannot grow BufferHolder by size " + neededSize + " because the size is negative"); + } if (neededSize > ARRAY_MAX - totalSize()) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } @@ -70,7 +76,8 @@ void grow(int neededSize) { if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(newLength); + final byte[] tmp = new byte[roundedSize]; Platform.copyMemory( buffer, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 85682cf6ea670..d2862c8f41d1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers} import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite @@ -39,7 +39,7 @@ class BufferHolderSparkSubmitSuite val argsForSparkSubmit = Seq( "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), "--name", "SPARK-22222", - "--master", "local-cluster[2,1,1024]", + "--master", "local-cluster[1,1,4096]", "--driver-memory", "4g", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", @@ -49,28 +49,36 @@ class BufferHolderSparkSubmitSuite } } -object BufferHolderSparkSubmitSuite { +object BufferHolderSparkSubmitSuite extends Assertions { def main(args: Array[String]): Unit = { val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - val holder = new BufferHolder(new UnsafeRow(1000)) + val unsafeRow = new UnsafeRow(1000) + val holder = new BufferHolder(unsafeRow) holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2)) - holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + assert(intercept[IllegalArgumentException] { + holder.grow(-1) + }.getMessage.contains("because the size is negative")) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + // while to reuse a buffer may happen, this test checks whether the buffer can be grown + holder.grow(ARRAY_MAX / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE)) - } + holder.grow(ARRAY_MAX / 2 + 7) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(Integer.MAX_VALUE / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(ARRAY_MAX - holder.totalSize()) + assert(unsafeRow.getSizeInBytes % 8 == 0) - private def roundToWord(len: Int): Int = { - ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + assert(intercept[IllegalArgumentException] { + holder.grow(ARRAY_MAX + 1 - holder.totalSize()) + }.getMessage.contains("because the size after growing")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index c7c386b5b838a..4e0f903a030aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -23,17 +23,15 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow class BufferHolderSuite extends SparkFunSuite { test("SPARK-16071 Check the size limit to avoid integer overflow") { - var e = intercept[UnsupportedOperationException] { + assert(intercept[UnsupportedOperationException] { new BufferHolder(new UnsafeRow(Int.MaxValue / 8)) - } - assert(e.getMessage.contains("too many fields")) + }.getMessage.contains("too many fields")) val holder = new BufferHolder(new UnsafeRow(1000)) holder.reset() holder.grow(1000) - e = intercept[UnsupportedOperationException] { + assert(intercept[IllegalArgumentException] { holder.grow(Integer.MAX_VALUE) - } - assert(e.getMessage.contains("exceeds size limitation")) + }.getMessage.contains("exceeds size limitation")) } }