-
Notifications
You must be signed in to change notification settings - Fork 29k
SPARK-1438 RDD.sample() make seed param optional #477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
0c247db
69619c6
8d05b1a
b9ebfe2
07bb06e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…er. python: use a separate instance of Random instead of seeding language api global Random instance.
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -467,11 +467,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
| val data = sc.parallelize(1 to 100, 2) | ||
|
|
||
| for (num <- List(5,20,100)) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Put spaces after the commas here |
||
| val sample = data.takeSample(withReplacement=false, num=num) | ||
| val sample = data.takeSample(withReplacement=false, num=num) | ||
| assert(sample.size === num) // Got exactly num elements | ||
| assert(sample.toSet.size === num) // Elements are distinct | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| } | ||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=false, 20, seed) | ||
| assert(sample.size === 20) // Got exactly 20 elements | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ | |
| from threading import Thread | ||
| import warnings | ||
| import heapq | ||
| import random | ||
| from random import Random | ||
|
|
||
| from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ | ||
| BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long | ||
|
|
@@ -382,11 +382,11 @@ def takeSample(self, withReplacement, num, seed=None): | |
| # If the first sample didn't turn out large enough, keep trying to take samples; | ||
| # this shouldn't happen often because we use a big multiplier for their initial size. | ||
| # See: scala/spark/RDD.scala | ||
| random.seed(seed) | ||
| rand = Random(seed) | ||
| while len(samples) < total: | ||
| samples = self.sample(withReplacement, fraction, random.randint(0,sys.maxint)).collect() | ||
| samples = self.sample(withReplacement, fraction, rand.randint(0,sys.maxint)).collect() | ||
|
|
||
| sampler = RDDSampler(withReplacement, fraction, random.randint(0,sys.maxint)) | ||
| sampler = RDDSampler(withReplacement, fraction, rand.randint(0,sys.maxint)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Put spaces after the comma here and in other instances of |
||
| sampler.shuffle(samples) | ||
| return samples[0:total] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ def __init__(self, withReplacement, fraction, seed=None): | |
| print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." | ||
| self._use_numpy = False | ||
|
|
||
| self._seed = seed | ||
| self._seed = seed if seed is not None else random.randint(0,sys.maxint) | ||
| self._withReplacement = withReplacement | ||
| self._fraction = fraction | ||
| self._random = None | ||
|
|
@@ -38,17 +38,15 @@ def initRandomGenerator(self, split): | |
| if self._use_numpy: | ||
| import numpy | ||
| self._random = numpy.random.RandomState(self._seed) | ||
| for _ in range(0, split): | ||
| # discard the next few values in the sequence to have a | ||
| # different seed for the different splits | ||
| self._random.randint(sys.maxint) | ||
| else: | ||
| import random | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we have imported random at the beginning, this line is unnecessary. |
||
| random.seed(self._seed) | ||
| for _ in range(0, split): | ||
| # discard the next few values in the sequence to have a | ||
| # different seed for the different splits | ||
| random.randint(0, sys.maxint) | ||
| self._random = random.Random(self._seed) | ||
|
|
||
| for _ in range(0, split): | ||
| # discard the next few values in the sequence to have a | ||
| # different seed for the different splits | ||
| self._random.randint(0, sys.maxint) | ||
|
|
||
| self._split = split | ||
| self._rand_initialized = True | ||
|
|
||
|
|
@@ -59,7 +57,7 @@ def getUniformSample(self, split): | |
| if self._use_numpy: | ||
| return self._random.random_sample() | ||
| else: | ||
| return random.uniform(0.0, 1.0) | ||
| return self._random.uniform(0.0, 1.0) | ||
|
|
||
| def getPoissonSample(self, split, mean): | ||
| if not self._rand_initialized or split != self._split: | ||
|
|
@@ -73,26 +71,26 @@ def getPoissonSample(self, split, mean): | |
| num_arrivals = 1 | ||
| cur_time = 0.0 | ||
|
|
||
| cur_time += random.expovariate(mean) | ||
| cur_time += self._random.expovariate(mean) | ||
|
|
||
| if cur_time > 1.0: | ||
| return 0 | ||
|
|
||
| while(cur_time <= 1.0): | ||
| cur_time += random.expovariate(mean) | ||
| cur_time += self._random.expovariate(mean) | ||
| num_arrivals += 1 | ||
|
|
||
| return (num_arrivals - 1) | ||
|
|
||
| def shuffle(self, vals): | ||
| if self._random == None or split != self._split: | ||
| if self._random == None: | ||
| self.initRandomGenerator(0) # this should only ever called on the master so | ||
| # the split does not matter | ||
|
|
||
| if self._use_numpy: | ||
| self._random.shuffle(vals) | ||
| else: | ||
| random.shuffle(vals, self._random) | ||
| self._random.shuffle(vals, self._random.random) | ||
|
|
||
| def func(self, split, iterator): | ||
| if self._withReplacement: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to say what the seed defaults to here since users won't understand it; just say
`@param seed random seed` and they can guess that if you don't specify it, we will choose one.