Skip to content

Commit 6000328

Browse files
committed
added python api and fixed test
1 parent 3c11d1b commit 6000328

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,30 @@ def distinct(self):
425425
def sample(self, withReplacement, fraction, seed=None):
426426
"""Returns a sampled subset of this :class:`DataFrame`.
427427
428-
>>> df.sample(False, 0.5, 97).count()
428+
>>> df.sample(False, 0.5, 42).count()
429429
1
430430
"""
431431
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
432432
seed = seed if seed is not None else random.randint(0, sys.maxsize)
433433
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
434434
return DataFrame(rdd, self.sql_ctx)
435435

436+
def randomSplit(self, weights, seed=None):
437+
"""Randomly splits this :class:`DataFrame` with the provided weights.
438+
439+
>>> splits = df4.randomSplit([1.0, 2.0], 24)
440+
>>> splits[0].count()
441+
1
442+
443+
>>> splits[1].count()
444+
3
445+
"""
446+
for w in weights:
447+
assert w >= 0.0, "Negative weight value: %s" % w
448+
seed = seed if seed is not None else random.randint(0, sys.maxsize)
449+
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
450+
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
451+
436452
@property
437453
def dtypes(self):
438454
"""Returns all column names and their data types as a list.

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ class DataFrame private[sql](
728728
/**
729729
* Randomly splits this [[DataFrame]] with the provided weights.
730730
*
731-
* @param weights weights for splits, will be normalized if they don't sum to 1
731+
* @param weights weights for splits, will be normalized if they don't sum to 1.
732732
* @param seed Seed for sampling.
733733
* @group dfops
734734
*/
@@ -743,13 +743,24 @@ class DataFrame private[sql](
743743
/**
744744
* Randomly splits this [[DataFrame]] with the provided weights.
745745
*
746-
* @param weights weights for splits, will be normalized if they don't sum to 1
746+
* @param weights weights for splits, will be normalized if they don't sum to 1.
747747
* @group dfops
748748
*/
749749
def randomSplit(weights: Array[Double]): Array[DataFrame] = {
750750
randomSplit(weights, Utils.random.nextLong)
751751
}
752752

753+
/**
754+
* Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
755+
*
756+
* @param weights weights for splits, will be normalized if they don't sum to 1.
757+
* @param seed Seed for sampling.
758+
* @group dfops
759+
*/
760+
def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
761+
randomSplit(weights.toArray, seed)
762+
}
763+
753764
/**
754765
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
755766
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of

0 commit comments

Comments
 (0)