core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
- def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
}

val left = num - results.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
@@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
- job.flatMap {_ =>
+ job.flatMap { _ =>
buf.foreach(results ++= _.take(num - results.size))
continue(partsScanned + p.size)
}
}

- new ComplexFutureAction[Seq[T]](continue(0L)(_))
+ new ComplexFutureAction[Seq[T]](continue(0)(_))
}

/**
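The AsyncRDDActions.scala hunks above keep `partsScanned` as an `Int` while still letting the upper bound of the partition range be computed in `Long` before it is capped and narrowed. Below is a minimal sketch of why that matters; it is not the Spark implementation, and it assumes `numPartsToTry` is a `Long` (its declaration is elided in the hunks above) while `partsScanned` and `totalParts` are `Int`s, as in the new code.

```scala
// Minimal sketch, not Spark's code. Assumes numPartsToTry is a Long (declaration
// elided in the diff above) while partsScanned and totalParts are Ints.
object TakeRangeSketch {
  def nextPartitions(partsScanned: Int, numPartsToTry: Long, totalParts: Int): Range = {
    // Int + Long promotes to Long, so the sum cannot wrap around, and
    // min(..., totalParts) guarantees the upper bound fits back into an Int.
    partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
  }

  def main(args: Array[String]): Unit = {
    // In pure Int arithmetic the same sum would overflow and yield an empty range:
    println(1 + Int.MaxValue)                          // -2147483648
    println(nextPartitions(1, Int.MaxValue.toLong, 2)) // a one-element range, capped at totalParts
  }
}
```

Keeping `partsScanned` as an `Int` also means the resulting range of partition indices can be handed to `runJob` without the `.toInt` cast the old lower bound needed.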
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag](
}

val left = num - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(num - buf.size))
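Both the blocking `RDD.take` shown here and the asynchronous variant above scan partitions in growing batches until enough rows have been collected, which is why the comments stress that the batch size may exceed `totalParts` before being capped. The following is a simplified, self-contained sketch of that loop shape only: `fetchPartition` is a hypothetical stand-in for the job submitted through `sc.runJob`, and the 4x growth factor is illustrative rather than Spark's actual scaling logic, which is elided between the hunks.

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified sketch of the scanning loop, not Spark's code. fetchPartition stands in
// for the per-partition job run through sc.runJob; the 4x growth is illustrative only.
object TakeLoopSketch {
  def takeSketch[T](num: Int, totalParts: Int, fetchPartition: Int => Array[T]): Seq[T] = {
    val buf = new ArrayBuffer[T]
    var partsScanned = 0
    var numPartsToTry = 1L
    while (buf.size < num && partsScanned < totalParts) {
      // Upper bound computed in Long, capped at totalParts, then narrowed to an Int range.
      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
      val res = p.map(fetchPartition)             // one array of rows per scanned partition
      res.foreach(buf ++= _.take(num - buf.size)) // keep only as many rows as still needed
      partsScanned += p.size
      numPartsToTry *= 4                          // scan more partitions on the next pass
    }
    buf.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Four two-element "partitions"; asking for three rows takes two passes (0, then 1-3).
    println(takeSketch(3, 4, (i: Int) => Array(i * 10, i * 10 + 1))) // rows 0, 1, 10
  }
}
```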
4 changes: 4 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
assert(nums.take(501) === (1 to 501).toArray)
assert(nums.take(999) === (1 to 999).toArray)
assert(nums.take(1000) === (1 to 999).toArray)

+ nums = sc.parallelize(1 to 2, 2)
+ assert(nums.take(2147483638).size === 2)
+ assert(nums.takeAsync(2147483638).get.size === 2)
}

test("top with predefined ordering") {
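The new RDDSuite assertions probe the boundary with 2147483638, which sits just below `Int.MaxValue` (2147483647): the requested count is still a legal `Int`, but scaling it while estimating how many partitions to try would overflow 32-bit arithmetic if the intermediate value were not held in a `Long`. A quick illustrative check (the factor of 4 is arbitrary, not Spark's exact scaling formula):

```scala
// REPL-style check; the factor of 4 is arbitrary, not Spark's exact scaling.
val n = 2147483638     // a legal Int; Int.MaxValue is 2147483647
println(n * 4)         // Int arithmetic wraps around to -40
println(n.toLong * 4)  // 8589934552 once the intermediate value is a Long
```

With only two elements across two partitions, both `take` and `takeAsync` should return exactly two rows, which is what the new assertions pin down.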
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {

val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = n - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val sc = sqlContext.sparkContext
- val res =
-   sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+ val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(n - buf.size))
partsScanned += p.size
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))

+ // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
+ checkAnswer(
+   sqlContext.range(2).limit(2147483638),
+   Row(0) :: Row(1) :: Nil
+ )
}

test("except") {
12 changes: 0 additions & 12 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}

test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
rdd.toDF("key").registerTempTable("spark12340")
checkAnswer(
sql("select key from spark12340 limit 2147483638"),
Row(1) :: Row(2) :: Row(3) :: Nil
)
assert(rdd.take(2147483638).size === 3)
assert(rdd.takeAsync(2147483638).get.size === 3)
}

}