Skip to content

Commit ad483fd

Browse files
committed
Handle the case of an empty RDD in takeSample
1 parent 9032f7c commit ad483fd

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ abstract class RDD[T: ClassTag](
310310
* Return a sampled subset of this RDD.
311311
*/
312312
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
313+
if (fraction < Double.MinValue || fraction > Double.MaxValue) {
314+
throw new Exception("Invalid fraction value:" + fraction)
315+
}
313316
if (withReplacement) {
314317
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
315318
} else {
@@ -344,6 +347,10 @@ abstract class RDD[T: ClassTag](
344347
throw new IllegalArgumentException("Negative number of elements requested")
345348
}
346349

350+
if (initialCount == 0) {
351+
return new Array[T](0)
352+
}
353+
347354
if (initialCount > Integer.MAX_VALUE - 1) {
348355
maxSelected = Integer.MAX_VALUE - 1
349356
} else {
@@ -362,7 +369,7 @@ abstract class RDD[T: ClassTag](
362369
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
363370

364371
// If the first sample didn't turn out large enough, keep trying to take samples;
365-
// this shouldn't happen often because we use a big multiplier for thei initial size
372+
// this shouldn't happen often because we use a big multiplier for the initial size
366373
while (samples.length < total) {
367374
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
368375
}

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
457457

458458
test("takeSample") {
459459
val data = sc.parallelize(1 to 100, 2)
460+
val emptySet = data.filter(_ => false)
461+
462+
val sample = emptySet.takeSample(false, 20, 1)
463+
assert(sample.size === 0)
460464
for (seed <- 1 to 5) {
461465
val sample = data.takeSample(withReplacement=false, 20, seed)
462466
assert(sample.size === 20) // Got exactly 20 elements

0 commit comments

Comments
 (0)