66 changes: 48 additions & 18 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -392,29 +392,28 @@ abstract class RDD[T: ClassTag](
this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
}.toArray
}

/**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
* Returns a fixed-size sampled subset of this RDD as an RDD
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return an RDD containing a sample of the specified size
*/
def takeSample(withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
def sampleByCount(withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): RDD[T] = {
val numStDev = 10.0

if (num < 0) {
throw new IllegalArgumentException("Negative number of elements requested")
} else if (num == 0) {
return new Array[T](0)
return new EmptyRDD[T](this.sc)
}

val initialCount = this.count()
if (initialCount == 0) {
return new Array[T](0)
return new EmptyRDD[T](this.sc)
}

val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
@@ -425,26 +424,57 @@

val rand = new Random(seed)
if (!withReplacement && num >= initialCount) {
return Utils.randomizeInPlace(this.collect(), rand)
return this
}

// Because sampling is stochastic, compute the sample size needed to ensure a sufficient
// number of samples with a 99.99% success rate
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
withReplacement)
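// For example (illustrative numbers): for num = 1,000 out of initialCount = 1,000,000
// without replacement, fraction comes out slightly above the naive 0.001, so that the
// probability of an undersized sample stays below 0.01%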

var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
var samples = this.sample(withReplacement, fraction, rand.nextInt())

// If the first sample didn't turn out large enough, keep trying to take samples;
// this shouldn't happen often because we use a big multiplier for the initial size
var numIters = 0
while (samples.length < num) {
var count = samples.count()

// computeFractionForSampleSize yields an upper bound on the needed fraction, so the
// first sample usually contains at least "num" elements. If it falls short, re-sample
// until it doesn't; any excess is dropped afterwards.
while (count < num) {
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
samples = this.sample(withReplacement, fraction, rand.nextInt())
numIters += 1
count = samples.count()
}

// After sampling is complete, we may have more than num samples. As the final step,
// pare the sample down to exactly num elements
if (count > num) {
samples = samples.zipWithIndex().filter(_._2 < num).map(_._1)
}

Utils.randomizeInPlace(samples, rand).take(num)
samples
}
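
For intuition, here is a minimal local sketch (not part of the patch) of the oversample-then-trim strategy used above. `sampleByCountLocal`, its 1.5x multiplier, and `inflatedFraction` are hypothetical stand-ins; the real code derives the fraction from `SamplingUtils.computeFractionForSampleSize` to guarantee the 99.99% success rate, and for simplicity this sketch covers only the without-replacement case.

```scala
import scala.util.Random

// Local analogue of sampleByCount on a plain Seq: oversample with an inflated
// fraction, re-sample on a shortfall, then trim the excess down to exactly num.
def sampleByCountLocal[T](data: Seq[T], num: Int, seed: Long): Seq[T] = {
  require(num >= 0, "Negative number of elements requested")
  if (num == 0 || data.isEmpty) return Seq.empty[T]
  if (num >= data.size) return data // mirrors the "return this" shortcut above

  val rand = new Random(seed)
  // Hypothetical stand-in for SamplingUtils.computeFractionForSampleSize
  val inflatedFraction = math.min(1.0, 1.5 * num.toDouble / data.size)

  // Bernoulli sampling: keep each element independently with probability
  // inflatedFraction, retrying in the (rare) case of a shortfall
  var samples = data.filter(_ => rand.nextDouble() < inflatedFraction)
  while (samples.size < num) {
    samples = data.filter(_ => rand.nextDouble() < inflatedFraction)
  }
  samples.take(num) // pare down the (likely) excess
}
```

A quick check: `sampleByCountLocal(1 to 100, 10, 42L)` returns exactly 10 distinct values from the range, since filtering a distinct sequence preserves distinctness.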

/**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
*/
def takeSample(withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {

// To preserve the behavior of the previous implementation, collect the sampled RDD
// and shuffle the resulting array before returning it
Utils.randomizeInPlace(sampleByCount(withReplacement, num, seed).collect(), new Random(seed))
}
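
A short usage sketch contrasting the two entry points (assuming a live `SparkContext` named `sc`; names and values are illustrative):

```scala
val rdd = sc.parallelize(1 to 100000)

// New: the sample stays distributed; nothing is collected on the driver
// until the caller asks for it
val sampledRdd: RDD[Int] = rdd.sampleByCount(withReplacement = false, num = 100)

// Old behavior preserved: collect the sample and shuffle it into a local array
val sampledArray: Array[Int] = rdd.takeSample(withReplacement = false, num = 100)
```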

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
62 changes: 59 additions & 3 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.rdd

import java.util.Random

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -85,7 +87,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.2)
assert(error(simpleRdd.countApproxDistinct(12, 0), size) < 0.1)
}

test("SparkContext.union") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
@@ -549,6 +551,60 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sampled.partitioner === rdd.partitioner)
}
}

test("sampleByCount") {
val count = 10000
val largeSize = 1000000
val smallSize = 50
val data = sc.parallelize(1 to largeSize, 2)
val dataSmall = sc.parallelize(1 to smallSize, 2)

val testCount = 10

val seed = System.currentTimeMillis()
val rand = new Random(seed)

for (i <- 1 to testCount) {
// When sampling without replacement, ensure all elements are distinct and we get the right
// number.

val sampleSize = rand.nextInt(count)
val samples = data.sampleByCount(withReplacement=false, sampleSize, seed)
assert(samples.count() == sampleSize)
assert(samples.distinct().count() == sampleSize)

// *********************************************************************
// When sampling with replacement, ensure we get the right
// number.
val sampleSize2 = rand.nextInt(smallSize) + smallSize
val samples2 = dataSmall.sampleByCount(withReplacement=true, sampleSize2, seed)
assert(samples2.count() == sampleSize2)

// *********************************************************************
// When sampling without replacement and requesting more elements than the RDD
// contains, ensure that the entire RDD is returned
val sampleSize3 = rand.nextInt(smallSize) + smallSize
val samples3 = dataSmall.sampleByCount(withReplacement=false, sampleSize3, seed)

assert(samples3.count() == smallSize)

// Values should still be distinct because the original data is the range 1 to smallSize
assert(samples3.distinct().count() == smallSize)

// *********************************************************************
// When sampling with replacement with a sample size larger than the dataset,
// ensure that the sampled elements are not all distinct
val sampleSize4 = count + rand.nextInt(count)
val samples4 = data.sampleByCount(withReplacement=true, sampleSize4, seed)

assert(samples4.count() == sampleSize4)

// Chance of getting all distinct elements is astronomically low, confirm that this
// doesn't happen
assert(samples4.distinct().count() < sampleSize4)
}
}

test("takeSample") {
val n = 1000000
@@ -579,13 +635,13 @@
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=n)
assert(sample.size === n) // Got exactly 100 elements
assert(sample.size === n) // Got exactly 1000000 elements
// Chance of getting all distinct elements is astronomically low, so test we got < n
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")