[SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python #18999
Changes from 2 commits
```diff
@@ -659,19 +659,77 @@ def distinct(self):
         return DataFrame(self._jdf.distinct(), self.sql_ctx)

     @since(1.3)
-    def sample(self, withReplacement, fraction, seed=None):
+    def sample(self, withReplacement=None, fraction=None, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.

         :param withReplacement: Sample with replacement or not (default False).
         :param fraction: Fraction of rows to generate, range [0.0, 1.0].
         :param seed: Seed for sampling (default a random seed).

         .. note:: This is not guaranteed to provide exactly the fraction specified of the total
             count of the given :class:`DataFrame`.

-        >>> df.sample(False, 0.5, 42).count()
-        2
-        """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
-        seed = seed if seed is not None else random.randint(0, sys.maxsize)
```
Author comment on the removed lines: I also removed …
```diff
-        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
-        return DataFrame(rdd, self.sql_ctx)
+        .. note:: `fraction` is required; `withReplacement` and `seed` are optional.
+
+        >>> df = spark.range(10)
+        >>> df.sample(0.5, 3).count()
+        4
+        >>> df.sample(fraction=0.5, seed=3).count()
+        4
+        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
+        1
+        >>> df.sample(1.0).count()
+        10
+        >>> df.sample(fraction=1.0).count()
+        10
+        >>> df.sample(False, fraction=1.0).count()
+        10
+        >>> df.sample("a").count()
+        Traceback (most recent call last):
+            ...
+        TypeError:...
+        >>> df.sample(seed="abc").count()
+        Traceback (most recent call last):
+            ...
+        TypeError:...
+        """
+
+        # For the cases below:
+        #     sample(True, 0.5 [, seed])
+        #     sample(True, fraction=0.5 [, seed])
+        #     sample(withReplacement=False, fraction=0.5 [, seed])
+        is_withReplacement_set = \
+            type(withReplacement) == bool and isinstance(fraction, float)
+
+        # For the case below:
+        #     sample(fraction=0.5 [, seed])
+        is_withReplacement_omitted_kwargs = \
+            withReplacement is None and isinstance(fraction, float)
+
+        # For the case below:
+        #     sample(0.5 [, seed])
+        is_withReplacement_omitted_args = isinstance(withReplacement, float)
+
+        if not (is_withReplacement_set
+                or is_withReplacement_omitted_kwargs
+                or is_withReplacement_omitted_args):
+            argtypes = [
+                str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
+            raise TypeError(
+                "withReplacement (optional), fraction (required) and seed (optional)"
+                " should be a bool, float and number; however, "
+                "got [%s]." % ", ".join(argtypes))
+
+        if is_withReplacement_omitted_args:
+            if fraction is not None:
+                seed = fraction
+            fraction = withReplacement
+            withReplacement = None
+
+        seed = long(seed) if seed is not None else None
+        args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
+        jdf = self._jdf.sample(*args)
+        return DataFrame(jdf, self.sql_ctx)

     @since(1.5)
     def sampleBy(self, col, fractions, seed=None):
```
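Read independently of Spark, the argument-dispatch logic in the new body can be exercised as a standalone sketch. The helper name `resolve_sample_args` is hypothetical (the real code inlines this in `DataFrame.sample`), and the Python 2 `long(seed)` coercion from the patch is dropped here:

```python
def resolve_sample_args(withReplacement=None, fraction=None, seed=None):
    """Normalize the three call shapes the patch accepts into a
    (withReplacement, fraction, seed) triple, mirroring the diff above."""
    # sample(True, 0.5 [, seed]) and keyword variants with both args given
    is_set = type(withReplacement) == bool and isinstance(fraction, float)
    # sample(fraction=0.5 [, seed])
    is_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
    # sample(0.5 [, seed]): the first positional arg is really the fraction
    is_omitted_args = isinstance(withReplacement, float)

    if not (is_set or is_omitted_kwargs or is_omitted_args):
        argtypes = [
            str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
        raise TypeError(
            "withReplacement (optional), fraction (required) and seed (optional)"
            " should be a bool, float and number; however, "
            "got [%s]." % ", ".join(argtypes))

    if is_omitted_args:
        # Shift the positional arguments one slot to the left.
        if fraction is not None:
            seed = fraction
        fraction = withReplacement
        withReplacement = None

    return withReplacement, fraction, seed
```

For example, `resolve_sample_args(0.5, 3)` yields `(None, 0.5, 3)`, while `resolve_sample_args(True, 0.5, 42)` passes all three through unchanged; note that `isinstance(True, float)` is false, which is why a leading bool is never mistaken for a fraction.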
I removed this, as it looks like it is already checked on the Scala / Java side.
I'd do the check in Python, so the error message is clearer. Best if the error messages match.
Hm, wouldn't we be better off avoiding a duplicated expression requirement? It looks like I would have to reimplement the check from spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala (lines 714 to 722 in 5ad1796) on the Python side. I have been thinking of avoiding that as long as the error message makes sense to Python users (which is not the case when it exposes non-Pythonic details, for example Java types such as `java.lang.Long` in the message), although I understand it is good to throw an exception early, before going to the JVM.
Yeah, it'd be better to have Python handle the simpler error checking.
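A Python-side guard of the kind being discussed might look like the following. This is a hypothetical sketch, not a check added by this PR; the bounds and tolerance only approximate what the Scala `Sample` operator enforces:

```python
def check_fraction(fraction, withReplacement=False):
    """Validate a sampling fraction before handing it to the JVM.

    Hypothetical guard: with replacement the fraction only needs to be
    non-negative; without replacement it must lie in [0, 1].
    """
    eps = 1e-11  # small tolerance for floating-point round-off
    if withReplacement:
        if fraction < -eps:
            raise ValueError(
                "Sampling fraction (%s) must be nonnegative "
                "with replacement" % fraction)
    elif fraction < -eps or fraction > 1.0 + eps:
        raise ValueError(
            "Sampling fraction (%s) must be on interval [0, 1] "
            "without replacement" % fraction)
```

The benefit the reviewers are after: the caller sees a plain Python `ValueError` with familiar types, instead of a Py4J-wrapped JVM exception mentioning `java.lang.Long` or Scala internals.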