74 changes: 66 additions & 8 deletions python/pyspark/sql/dataframe.py
@@ -659,19 +659,77 @@ def distinct(self):
         return DataFrame(self._jdf.distinct(), self.sql_ctx)
 
     @since(1.3)
-    def sample(self, withReplacement, fraction, seed=None):
+    def sample(self, withReplacement=None, fraction=None, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.
 
         :param withReplacement: Sample with replacement or not (default False).
         :param fraction: Fraction of rows to generate, range [0.0, 1.0].
         :param seed: Seed for sampling (default a random seed).
 
         .. note:: This is not guaranteed to provide exactly the fraction specified of the total
             count of the given :class:`DataFrame`.
 
-        >>> df.sample(False, 0.5, 42).count()
-        2
-        """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
Member Author (@HyukjinKwon):

I removed this, as it looks like it is already checked on the Scala / Java side:

    >>> df.sample(fraction=-0.1).count()
    ...
    pyspark.sql.utils.IllegalArgumentException: u'requirement failed: Sampling fraction (-0.1) must be on interval [0, 1] without replacement'

Contributor (@rxin), Aug 20, 2017:

I'd do the check in Python, so the error message is clearer. Best if the error messages match.

Member Author (@HyukjinKwon), Aug 21, 2017:

Hm, wouldn't we be better off not duplicating the requirement check? It looks like I would have to replicate this on the Python side:

    if (withReplacement) {
      require(
        fraction >= 0.0 - eps,
        s"Sampling fraction ($fraction) must be nonnegative with replacement")
    } else {
      require(
        fraction >= 0.0 - eps && fraction <= 1.0 + eps,
        s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement")
    }

I have been trying to avoid that when the error message makes sense to Python users (though not when it exposes non-Pythonic details, for example Java types such as java.lang.Long in the error message), although I understand it is good to throw the exception early, before going to the JVM.

Contributor:

Yeah, it'd be better to have Python handle the simpler error checking.
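
A minimal sketch of what that Python-side check could look like (hypothetical: the helper name and eps value are assumptions, mirroring the Scala require messages quoted above rather than code from this PR):

    def _validate_fraction(fraction, withReplacement):
        # Mirror the Scala-side requirement so the error is raised in Python,
        # before the call crosses into the JVM.
        eps = 1e-6  # tolerance, analogous to the Scala side's eps
        if withReplacement:
            if fraction < 0.0 - eps:
                raise ValueError(
                    "Sampling fraction (%s) must be nonnegative with replacement" % fraction)
        elif not (0.0 - eps <= fraction <= 1.0 + eps):
            raise ValueError(
                "Sampling fraction (%s) must be on interval [0, 1] without replacement" % fraction)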

-        seed = seed if seed is not None else random.randint(0, sys.maxsize)
Member Author (@HyukjinKwon):

I also removed random.randint(0, sys.maxsize) and instead call the Scala / Java side one directly.

-        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
-        return DataFrame(rdd, self.sql_ctx)
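
To make the change concrete, a condensed before/after view (both halves are taken from this diff, not new code): previously Python always materialized a random seed before calling into the JVM; now the seed argument is simply omitted when the caller does not supply one, and the JVM side picks its own.

    # Before: Python generated the seed itself when none was given.
    seed = seed if seed is not None else random.randint(0, sys.maxsize)
    rdd = self._jdf.sample(withReplacement, fraction, long(seed))

    # After: only the arguments actually supplied are forwarded; the
    # JVM-side sample() overload chooses a random seed when seed is omitted.
    args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
    jdf = self._jdf.sample(*args)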
+        .. note:: `fraction` is required, and `withReplacement` and `seed` are optional.
+
+        >>> df = spark.range(10)
+        >>> df.sample(0.5, 3).count()
+        4
+        >>> df.sample(fraction=0.5, seed=3).count()
+        4
+        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
+        1
+        >>> df.sample(1.0).count()
+        10
+        >>> df.sample(fraction=1.0).count()
+        10
+        >>> df.sample(False, fraction=1.0).count()
+        10
+        >>> df.sample("a").count()
+        Traceback (most recent call last):
+            ...
+        TypeError:...
+        >>> df.sample(seed="abc").count()
+        Traceback (most recent call last):
+            ...
+        TypeError:...
Member:

Maybe we shouldn't put the error cases in the doctests, but move them to unit tests instead? Also, these cases aren't really that meaningfully different to me as a user:

        >>> df.sample(0.5, 3).count()
        4
        >>> df.sample(fraction=0.5, seed=3).count()
        4
        >>> df.sample(1.0).count()
        10
        >>> df.sample(fraction=1.0).count()
        10
        >>> df.sample(False, fraction=1.0).count()
        10

Contributor:

That makes sense! Doctests are examples users can follow.
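
A minimal sketch of what such a unit test could look like (the class and test names are hypothetical, not taken from this PR):

    import unittest

    from pyspark.sql import SparkSession


    class SampleArgumentTests(unittest.TestCase):
        # Illustrative only: the error cases move out of the doctests.

        @classmethod
        def setUpClass(cls):
            cls.spark = SparkSession.builder.master("local[1]").getOrCreate()

        @classmethod
        def tearDownClass(cls):
            cls.spark.stop()

        def test_sample_rejects_wrong_argument_types(self):
            df = self.spark.range(10)
            # sample() validates argument types eagerly, so no action is needed.
            self.assertRaises(TypeError, lambda: df.sample("a"))
            self.assertRaises(TypeError, lambda: df.sample(seed="abc"))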

"""

# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
is_withReplacement_set = \
type(withReplacement) == bool and isinstance(fraction, float)

# For the case below:
# sample(faction=0.5 [, seed])
is_withReplacement_omitted_kwargs = \
withReplacement is None and isinstance(fraction, float)

# For the case below:
# sample(0.5 [, seed])
is_withReplacement_omitted_args = isinstance(withReplacement, float)

if not (is_withReplacement_set
or is_withReplacement_omitted_kwargs
or is_withReplacement_omitted_args):
argtypes = [
str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
raise TypeError(
"withReplacement (optional), fraction (required) and seed (optional)"
" should be a bool, float and number; however, "
"got %s." % ", ".join(argtypes))
Member:

With this change, all three parameters can be None by default, so argtypes seems to be an empty list here?

Member Author (@HyukjinKwon):

Yea, it looks so. Let me try to improve this message.
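
One hypothetical way the message could handle the all-None case (a sketch of the idea being discussed, not what the PR ultimately did):

    given = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
    if not given:
        raise TypeError("fraction is required; sample() was called without arguments.")
    raise TypeError(
        "withReplacement (optional), fraction (required) and seed (optional)"
        " should be a bool, float and number; however, "
        "got %s." % ", ".join(str(type(arg)) for arg in given))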


+        if is_withReplacement_omitted_args:
+            if fraction is not None:
+                seed = fraction
+            fraction = withReplacement
+            withReplacement = None
+
+        seed = long(seed) if seed is not None else None
+        args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
+        jdf = self._jdf.sample(*args)
+        return DataFrame(jdf, self.sql_ctx)

     @since(1.5)
     def sampleBy(self, col, fractions, seed=None):
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1866,7 +1866,8 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement).
+   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
+   * using a random seed.
    *
    * @param fraction Fraction of rows to generate, range [0.0, 1.0].
    *