Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,15 @@ private[spark] object SparkCoreErrors {
"configVal" -> toConfVal(FALLBACK_COMPRESSION_CODEC)))
}

def tooManyArrayElementsError(numElements: Long, maxRoundedArrayLength: Int): Throwable = {
new SparkIllegalArgumentException(
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE",
messageParameters = Map(
"numberOfElements" -> numElements.toString,
"maxRoundedArrayLength" -> maxRoundedArrayLength.toString)
)
}

private def quoteByDefault(elem: String): String = {
"\"" + elem + "\""
}
Expand Down
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.util.Utils

private[spark]
Expand Down Expand Up @@ -57,7 +58,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag](

override def getPartitions: Array[Partition] = {
// create the cross product split
val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length)
val partitionNum: Long = numPartitionsInRdd2.toLong * rdd1.partitions.length
if (partitionNum > Int.MaxValue) {
throw SparkCoreErrors.tooManyArrayElementsError(partitionNum, Int.MaxValue)
}
val array = new Array[Partition](partitionNum.toInt)
for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
val idx = s1.index * numPartitionsInRdd2 + s2.index
array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
Expand Down
16 changes: 16 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a)
}

test("SPARK-48656: number of cartesian partitions overflow") {
val numSlices: Int = 65536
val rdd1 = sc.parallelize(Seq(1, 2, 3), numSlices = numSlices)
val rdd2 = sc.parallelize(Seq(1, 2, 3), numSlices = numSlices)
checkError(
exception = intercept[SparkIllegalArgumentException] {
rdd1.cartesian(rdd2).partitions
},
errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE",
sqlState = "54000",
parameters = Map(
"numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString,
"maxRoundedArrayLength" -> Int.MaxValue.toString)
)
}

test("intersection") {
val all = sc.parallelize(1 to 10)
val evens = sc.parallelize(2 to 10 by 2)
Expand Down