
Commit 3469ec6

wayneguow and yaooqinn committed
[SPARK-48656][CORE] Do a length check and throw COLLECTION_SIZE_LIMIT_EXCEEDED error in CartesianRDD.getPartitions
### What changes were proposed in this pull request?

This PR aims to improve the error message by doing a length check and throwing a `COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE` error in `CartesianRDD.getPartitions`.

### Why are the changes needed?

To improve the error prompts shown to Spark users.

### Does this PR introduce _any_ user-facing change?

Yes, this PR changes a user-facing error class and message.

### How was this patch tested?

Added a new test case in `RDDSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47019 from wayneguow/SPARK-48656.

Lead-authored-by: Wei Guo <[email protected]>
Co-authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
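For context, the root cause is plain 32-bit overflow: the old code multiplied two `Int` partition counts, which wraps around before the result ever reaches the array allocation. A standalone illustration of the arithmetic (not part of the patch):

```scala
// 65536 * 65536 = 2^32, which wraps to 0 in 32-bit Int arithmetic;
// widening to Long first preserves the real partition count.
val n1 = 65536
val n2 = 65536
val wrapped: Int = n1 * n2          // overflows to 0
val widened: Long = n1.toLong * n2  // 4294967296L
assert(widened > Int.MaxValue)      // the condition the patch now checks
```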
1 parent b99bb00 · commit 3469ec6

File tree

3 files changed: +31 −1 lines changed


core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala

Lines changed: 9 additions & 0 deletions

```diff
@@ -501,6 +501,15 @@ private[spark] object SparkCoreErrors {
         "configVal" -> toConfVal(FALLBACK_COMPRESSION_CODEC)))
   }
 
+  def tooManyArrayElementsError(numElements: Long, maxRoundedArrayLength: Int): Throwable = {
+    new SparkIllegalArgumentException(
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE",
+      messageParameters = Map(
+        "numberOfElements" -> numElements.toString,
+        "maxRoundedArrayLength" -> maxRoundedArrayLength.toString)
+    )
+  }
+
   private def quoteByDefault(elem: String): String = {
     "\"" + elem + "\""
   }
```
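The new helper is a thin wrapper over `SparkIllegalArgumentException`. As a minimal sketch of the intended usage pattern (the `allocate` function below is hypothetical, not Spark API): any caller that sizes a JVM array — whose length is capped at `Int` range — from a `Long` count can guard the allocation first, exactly as `CartesianRDD` does in the next file:

```scala
// Sketch only: guard a Long-sized element count before allocating
// a JVM array, whose length must fit in an Int.
def allocate[T: scala.reflect.ClassTag](count: Long): Array[T] = {
  if (count > Int.MaxValue) {
    throw SparkCoreErrors.tooManyArrayElementsError(count, Int.MaxValue)
  }
  new Array[T](count.toInt)
}
```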

core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala

Lines changed: 6 additions & 1 deletion

```diff
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
 import scala.reflect.ClassTag
 
 import org.apache.spark._
+import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.util.Utils
 
 private[spark]
@@ -57,7 +58,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
 
   override def getPartitions: Array[Partition] = {
     // create the cross product split
-    val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length)
+    val partitionNum: Long = numPartitionsInRdd2.toLong * rdd1.partitions.length
+    if (partitionNum > Int.MaxValue) {
+      throw SparkCoreErrors.tooManyArrayElementsError(partitionNum, Int.MaxValue)
+    }
+    val array = new Array[Partition](partitionNum.toInt)
     for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
       val idx = s1.index * numPartitionsInRdd2 + s2.index
       array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
```
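For reference, the boundary this guard sits on: `Int.MaxValue` is 2,147,483,647, so the smallest symmetric partition counts that trip it are 46,341 × 46,341 (46,340² = 2,147,395,600 still fits). A quick check of that arithmetic, purely illustrative:

```scala
// Boundary arithmetic for the new guard (illustrative, not Spark code).
val limit = Int.MaxValue            // 2147483647
val below = 46340L * 46340L         // 2147395600: still allocatable
val above = 46341L * 46341L         // 2147488281: trips the guard
assert(below <= limit && above > limit)
```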

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 16 additions & 0 deletions

```diff
@@ -914,6 +914,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
     assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a)
   }
 
+  test("SPARK-48656: number of cartesian partitions overflow") {
+    val numSlices: Int = 65536
+    val rdd1 = sc.parallelize(Seq(1, 2, 3), numSlices = numSlices)
+    val rdd2 = sc.parallelize(Seq(1, 2, 3), numSlices = numSlices)
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        rdd1.cartesian(rdd2).partitions
+      },
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE",
+      sqlState = "54000",
+      parameters = Map(
+        "numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString,
+        "maxRoundedArrayLength" -> Int.MaxValue.toString)
+    )
+  }
+
   test("intersection") {
     val all = sc.parallelize(1 to 10)
     val evens = sc.parallelize(2 to 10 by 2)
```
