diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d80d94a58834..5a3aaf1d09c1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -19,29 +19,24 @@ package org.apache.spark.rdd
 
 import java.util.Random
 
-import scala.collection.{mutable, Map}
-import scala.collection.mutable.ArrayBuffer
-import scala.language.implicitConversions
-import scala.reflect.{classTag, ClassTag}
-
 import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text}
 import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.{BytesWritable, NullWritable, Text}
 import org.apache.hadoop.mapred.TextOutputFormat
-
-import org.apache.spark._
 import org.apache.spark.Partitioner._
+import org.apache.spark._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.partial.BoundedDouble
-import org.apache.spark.partial.CountEvaluator
-import org.apache.spark.partial.GroupedCountEvaluator
-import org.apache.spark.partial.PartialResult
+import org.apache.spark.partial.{BoundedDouble, CountEvaluator, GroupedCountEvaluator, PartialResult}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
-  SamplingUtils}
+import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.{Map, mutable}
+import scala.language.implicitConversions
+import scala.reflect.{ClassTag, classTag}
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -407,11 +402,26 @@ abstract class RDD[T: ClassTag](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new PartitionwiseSampledRDD[T, T](
-        this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+      randomSampleWithRange(x(0), x(1), seed)
     }.toArray
   }
 
+  /**
+   * Internal method exposed for random splits in DataFrames. Samples an RDD given a probability
+   * range.
+   * @param lb lower bound to use for the Bernoulli sampler
+   * @param ub upper bound to use for the Bernoulli sampler
+   * @param seed the seed for the random number generator
+   * @return A random sub-sample of the RDD without replacement.
+   */
+  private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+    this.mapPartitionsWithIndex { case (index, partition) =>
+      val sampler = new BernoulliCellSampler[T](lb, ub)
+      sampler.setSeed(seed + index)
+      sampler.sample(partition)
+    }
+  }
+
   /**
    * Return a fixed-size sampled subset of this RDD in an array
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d5aba9644ff..5ae03adee90a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -279,10 +279,11 @@ package object dsl {
       Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
 
     def sample(
-        fraction: Double,
+        lowerBound: Double,
+        upperBound: Double,
         withReplacement: Boolean = true,
         seed: Int = (math.random * 1000).toInt): LogicalPlan =
-      Sample(fraction, withReplacement, seed, logicalPlan)
+      Sample(lowerBound, upperBound, withReplacement, seed, logicalPlan)
 
     // TODO specify the output column names
     def generate(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index bbc94a7ab339..ca4f2f38636f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
-  extends UnaryNode {
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper bound of the sampling probability. The expected fraction sampled
+ *                   will be upperBound - lowerBound.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LogicalPlan
+ */
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ca6ae482eb2a..ee0fbcb8e714 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,31 +20,30 @@ package org.apache.spark.sql
 import java.io.CharArrayWriter
 import java.sql.DriverManager
 
-import scala.collection.JavaConversions._
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
 import com.fasterxml.jackson.core.JsonFactory
-
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
+
 private[sql] object DataFrame {
   def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
@@ -711,7 +710,7 @@ class DataFrame private[sql](
    * @group dfops
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
-    Sample(fraction, withReplacement, seed, logicalPlan)
+    Sample(0.0, fraction, withReplacement, seed, logicalPlan)
   }
 
   /**
@@ -966,6 +965,34 @@ class DataFrame private[sql](
       schema,
       needsConversion = false)
   }
 
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1
+   * @param seed random seed
+   *
+   * @return split DataFrames in an array
+   */
+  def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[DataFrame] = {
+    val sum = weights.sum
+    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+    // Cache the underlying data so that the repeated scans below sample the same rows.
+    this.cache()
+    normalizedCumWeights.sliding(2).map { x =>
+      new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+    }.toArray
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+    randomSplit(weights, Utils.random.nextLong)
+  }
+
   /**
    * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
    * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 030ef118f75d..e3a633493b3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -300,8 +300,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Expand(projections, output, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
         execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
-      case logical.Sample(fraction, withReplacement, seed, child) =>
-        execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+      case logical.Sample(lb, ub, withReplacement, seed, child) =>
+        execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d286fe81bee5..da032828425d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -25,8 +24,9 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.{CompletionIterator, MutablePair}
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.{CompletionIterator, MutablePair}
+import org.apache.spark.{HashPartitioner, SparkEnv}
 
 /**
  * :: DeveloperApi ::
@@ -63,16 +63,31 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
 
 /**
  * :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper bound of the sampling probability. The expected fraction sampled
+ *                   will be upperBound - lowerBound.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
  */
 @DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
-  extends UnaryNode
-{
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: SparkPlan)
+  extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
   // TODO: How to pick seed?
   override def execute(): RDD[Row] = {
-    child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+    if (withReplacement) {
+      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+    } else {
+      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5ec06d448e50..9f34f57fc4a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql
 
-import scala.language.postfixOps
-
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test.TestSQLContext.sql
+import org.apache.spark.sql.test.TestSQLContext.{logicalPlanToSparkQuery, sql}
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}
+import org.apache.spark.sql.types._
+
+import scala.language.postfixOps
 
 class DataFrameSuite extends QueryTest {
 
@@ -391,6 +390,24 @@ class DataFrameSuite extends QueryTest {
       Row(null, null))
   }
 
+  test("SPARK-7156 add randomSplit") {
+    val n = 600
+    val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array(1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+        data.collect().toList,
+        "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50) // std = 9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("count") {
     assert(testData2.count() === testData2.map(_ => 1).count())
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0ea6d57b816c..d4f8f2e3f7be 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
             && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
           s"Sampling fraction ($fraction) must be on interval [0, 100]")
-        Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
-          relation)
+        Sample(0.0, fraction.toDouble / 100, withReplacement = false,
+          (math.random * 1000).toInt, relation)
       case Token("TOK_TABLEBUCKETSAMPLE",
             Token(numerator, Nil) ::
            Token(denominator, Nil) :: Nil) =>
        val fraction = numerator.toDouble / denominator.toDouble
-        Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+        Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
       case a: ASTNode =>
         throw new NotImplementedError(
           s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
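
A note for reviewers on why the splits neither overlap nor drop rows: randomSplit normalizes the weights and folds them into cumulative bounds with scanLeft, and each split then samples one [lb, ub) slice of [0, 1) with the same seed, so every row's uniform draw lands in exactly one slice. A minimal, Spark-free sketch of just that bookkeeping (the object name is illustrative, not part of the patch):

object RandomSplitRanges {
  def main(args: Array[String]): Unit = {
    // Mirrors the patch's bookkeeping: normalize the weights, then fold them
    // into cumulative bounds. Array(1, 2, 3) yields the disjoint ranges
    // [0, 1/6), [1/6, 1/2), [1/2, 1).
    val weights = Array(1.0, 2.0, 3.0)
    val sum = weights.sum
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    normalizedCumWeights.sliding(2).foreach { case Array(lb, ub) =>
      println(f"[$lb%.4f, $ub%.4f)") // each row's draw lands in exactly one range
    }
  }
}

This is also why the with-replacement branch of execution.Sample falls back to a plain sample(withReplacement, upperBound - lowerBound, seed): with-replacement sampling draws a count per row rather than a uniform position, so there is no probability range to slice.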
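
And a usage sketch of the new DataFrame.randomSplit as one might exercise it in spark-shell (assumes a running SparkContext named sc; the 70/30 weights, the seed, and the train/test names are illustrative, not part of the patch):

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc) // `sc` is the shell's SparkContext
import sqlContext.implicits._

val df = sqlContext.sparkContext.parallelize(1 to 1000, 4).toDF("id")

// The weights are normalized, so Array(7, 3) would behave identically.
val Array(train, test) = df.randomSplit(Array(0.7, 0.3), seed = 42L)

// Without replacement, the two samples partition the rows:
// nothing is duplicated and nothing is dropped.
assert(train.count() + test.count() == df.count())
assert(train.unionAll(test).sort("id").collect().toList == df.sort("id").collect().toList)

Each split is a separate scan of the parent plan with a different Sample range, which is why randomSplit caches the DataFrame first: the repeated scans must see the same rows for the ranges to partition the data.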