44 changes: 27 additions & 17 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -19,29 +19,24 @@ package org.apache.spark.rdd

import java.util.Random

-import scala.collection.{mutable, Map}
-import scala.collection.mutable.ArrayBuffer
-import scala.language.implicitConversions
-import scala.reflect.{classTag, ClassTag}
-
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text}
import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.{BytesWritable, NullWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat

-import org.apache.spark._
import org.apache.spark.Partitioner._
+import org.apache.spark._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.partial.BoundedDouble
-import org.apache.spark.partial.CountEvaluator
-import org.apache.spark.partial.GroupedCountEvaluator
-import org.apache.spark.partial.PartialResult
+import org.apache.spark.partial.{BoundedDouble, CountEvaluator, GroupedCountEvaluator, PartialResult}
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
-  SamplingUtils}
+import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.{Map, mutable}
+import scala.language.implicitConversions
+import scala.reflect.{ClassTag, classTag}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -407,11 +402,26 @@ abstract class RDD[T: ClassTag](
    val sum = weights.sum
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    normalizedCumWeights.sliding(2).map { x =>
-      new PartitionwiseSampledRDD[T, T](
-        this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+      randomSampleWithRange(x(0), x(1), seed)
    }.toArray
  }

+  /**
+   * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
+   * range.
+   * @param lb lower bound to use for the Bernoulli sampler
+   * @param ub upper bound to use for the Bernoulli sampler
+   * @param seed the seed for the Random number generator
+   * @return A random sub-sample of the RDD without replacement.
+   */
+  private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+    this.mapPartitionsWithIndex { case (index, partition) =>
+      val sampler = new BernoulliCellSampler[T](lb, ub)
+      sampler.setSeed(seed + index)
+      sampler.sample(partition)
+    }
+  }
+
  /**
   * Return a fixed-size sampled subset of this RDD in an array
   *
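Note on the hunk above: `randomSplit` turns the weights into adjacent probability ranges via a normalized cumulative sum, and every range shares one `seed` (offset by partition index in `randomSampleWithRange`, which keeps each split deterministic). Since `BernoulliCellSampler` keeps an element exactly when its uniform draw falls in `[lb, ub)`, and the ranges are disjoint and cover `[0, 1]`, each element lands in exactly one split. A minimal, Spark-free sketch of the range computation (the object and variable names are illustrative, not from the patch):

```scala
// Sketch: the cumulative-weight arithmetic used by randomSplit, outside Spark.
object WeightsToRanges extends App {
  val weights = Array(1.0, 2.0, 3.0)                    // same weights as the test below
  val sum = weights.sum                                 // 6.0
  // Identical expression to the patch: normalized cumulative weights.
  val cum = weights.map(_ / sum).scanLeft(0.0d)(_ + _)  // [0.0, 0.1667, 0.5, 1.0]
  val ranges = cum.sliding(2).map(x => (x(0), x(1))).toList
  // ranges == List((0.0, 0.1667), (0.1667, 0.5), (0.5, 1.0))
  ranges.foreach { case (lb, ub) =>
    println(f"[$lb%.4f, $ub%.4f)  expected fraction = ${ub - lb}%.4f")
  }
}
```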
@@ -279,10 +279,11 @@ package object dsl
      Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)

    def sample(
-        fraction: Double,
+        lowerBound: Double,
+        upperBound: Double,
        withReplacement: Boolean = true,
        seed: Int = (math.random * 1000).toInt): LogicalPlan =
-      Sample(fraction, withReplacement, seed, logicalPlan)
+      Sample(lowerBound, upperBound, withReplacement, seed, logicalPlan)

    // TODO specify the output column names
    def generate(
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
-  extends UnaryNode {
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LogicalPlan
+ */
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LogicalPlan) extends UnaryNode {

  override def output: Seq[Attribute] = child.output
}
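Migration note for this node: every pre-existing call site that passed a single `fraction` now passes `0.0` as the lower bound (see `DataFrame.sample` and `HiveQl` below), so the expected fraction is `upperBound - lowerBound` in both forms. A small sketch of that rewrite; the helper name is invented for illustration, not part of the patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample}

// Sketch: drop-in rewrite from the old 4-argument Sample to the new 5-argument one.
object SampleMigration {
  def rewriteOldSample(
      fraction: Double,
      withReplacement: Boolean,
      seed: Long,
      child: LogicalPlan): Sample = {
    // lowerBound = 0.0 preserves the old semantics: expected fraction = fraction - 0.0.
    Sample(0.0, fraction, withReplacement, seed, child)
  }
}
```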
52 changes: 39 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,31 +20,30 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.sql.DriverManager

-import scala.collection.JavaConversions._
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal


private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
@@ -711,7 +710,7 @@ class DataFrame private[sql](
   * @group dfops
   */
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
-    Sample(fraction, withReplacement, seed, logicalPlan)
+    Sample(0.0, fraction, withReplacement, seed, logicalPlan)
  }

/**
@@ -966,6 +965,33 @@ class DataFrame private[sql](
      schema, needsConversion = false)
  }

+  /**
+   * Randomly splits this DataFrame with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1
+   * @param seed random seed
+   *
+   * @return split DataFrames in an array
+   */
+  def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[DataFrame] = {
+    val sum = weights.sum
+    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+    this.cache()
+    normalizedCumWeights.sliding(2).map { x =>
+      new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+    }.toArray
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+    randomSplit(weights, Utils.random.nextLong)
+  }
+
  /**
   * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
   * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
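A usage sketch for the new `randomSplit` API (Spark 1.x style; a live `SQLContext` named `sqlContext` is assumed, as in the test suite below). Note the `this.cache()` call in the implementation: each split re-evaluates the logical plan, and caching pins the underlying data so the disjoint sampling ranges really partition the same rows:

```scala
// Usage sketch; assumes a live SQLContext named sqlContext (setup not shown).
import sqlContext.implicits._

val df = sqlContext.sparkContext.parallelize(1 to 1000).toDF("id")

// Weights are normalized: Array(0.7, 0.3) and Array(7, 3) behave the same.
val Array(train, test) = df.randomSplit(Array(0.7, 0.3), seed = 42L)

// With one seed and disjoint [lb, ub) ranges over a cached plan, the splits
// partition the input, so the counts add back up to the original.
assert(train.count() + test.count() == df.count())
```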
@@ -300,8 +300,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        execution.Expand(projections, output, planLater(child)) :: Nil
      case logical.Aggregate(group, agg, child) =>
        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
-      case logical.Sample(fraction, withReplacement, seed, child) =>
-        execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+      case logical.Sample(lb, ub, withReplacement, seed, child) =>
+        execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
      case logical.LocalRelation(output, data) =>
        LocalTableScan(output, data) :: Nil
      case logical.Limit(IntegerLiteral(limit), child) =>
@@ -17,16 +17,16 @@

package org.apache.spark.sql.execution

-import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.{CompletionIterator, MutablePair}
+import org.apache.spark.{HashPartitioner, SparkEnv}

/**
* :: DeveloperApi ::
@@ -63,16 +63,31 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

/**
 * :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
 */
@DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
-  extends UnaryNode
-{
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: SparkPlan)
+  extends UnaryNode {
  override def output: Seq[Attribute] = child.output

  // TODO: How to pick seed?
  override def execute(): RDD[Row] = {
-    child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+    if (withReplacement) {
+      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+    } else {
+      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+    }
  }
}

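Why the branch in `execute()`: with-replacement sampling is Poisson-based and has no notion of a probability range, so only the range's width (`upperBound - lowerBound`, the expected fraction) can be honored; the without-replacement path goes through the new `randomSampleWithRange`, which is what lets sibling `Sample` operators with a shared seed and disjoint ranges select disjoint row sets. A sketch of the same decision at the RDD level (not the committed code; it sits in a Spark package only because `randomSampleWithRange` is `private[spark]`):

```scala
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// Sketch: the two sampling paths chosen by Sample.execute() above.
object SampleSketch {
  def sampleRows(rows: RDD[Row], lb: Double, ub: Double,
      withReplacement: Boolean, seed: Long): RDD[Row] = {
    if (withReplacement) {
      // Poisson sampling: only the expected fraction (the range width) is usable.
      rows.sample(withReplacement, ub - lb, seed)
    } else {
      // Bernoulli range sampling: disjoint ranges plus a shared seed yield disjoint rows.
      rows.randomSampleWithRange(lb, ub, seed)
    }
  }
}
```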
28 changes: 22 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -17,14 +17,13 @@

package org.apache.spark.sql

-import scala.language.postfixOps
-
import org.apache.spark.sql.functions._
Review comment (Contributor) on this import block: scala imports should be first

-import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test.TestSQLContext.sql
+import org.apache.spark.sql.test.TestSQLContext.{logicalPlanToSparkQuery, sql}
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}
+import org.apache.spark.sql.types._
+
+import scala.language.postfixOps


class DataFrameSuite extends QueryTest {
@@ -391,6 +390,23 @@ class DataFrameSuite extends QueryTest {
      Row(null, null))
  }

+  test("SPARK-7156 add randomSplit") {
+    val n = 600
+    val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array(1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == data.collect().toList,
+        "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50)  // std = 9.13
+      assert(math.abs(s(1) - 200) < 50)  // std = 11.55
+      assert(math.abs(s(2) - 300) < 50)  // std = 12.25
+    }
+  }
+
  test("count") {
    assert(testData2.count() === testData2.map(_ => 1).count())

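The `std` comments in the new test are the binomial standard deviations of the split sizes: each row lands in split i independently with probability p_i = w_i / sum(w), so the size is Binomial(n, p_i) with sigma = sqrt(n * p_i * (1 - p_i)). The tolerance of 50 is more than four standard deviations for every split, so spurious failures should be rare. A quick check:

```scala
// Reproduces the std values quoted in the test comments above.
object SplitStdCheck extends App {
  val n = 600
  val ps = Seq(1.0 / 6, 2.0 / 6, 3.0 / 6)  // weights 1:2:3, normalized
  ps.foreach { p =>
    val std = math.sqrt(n * p * (1 - p))
    println(f"mean = ${n * p}%5.1f  std = $std%5.2f")  // prints 9.13, 11.55, 12.25
  }
}
```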
@@ -887,13 +887,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
          fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
            && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
          s"Sampling fraction ($fraction) must be on interval [0, 100]")
-        Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
-          relation)
+        Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, relation)
      case Token("TOK_TABLEBUCKETSAMPLE",
        Token(numerator, Nil) ::
        Token(denominator, Nil) :: Nil) =>
        val fraction = numerator.toDouble / denominator.toDouble
-        Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+        Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
      case a: ASTNode =>
        throw new NotImplementedError(
          s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
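For readers unfamiliar with Hive's sampling clauses, both parse rules above now reduce to the same `[0.0, fraction)` form; a hedged illustration of the mapping follows (the table name is arbitrary, and this is not literal parser output). Note that `(math.random * 1000).toInt` draws a fresh, coarse seed per parsed query, so the `// TODO: How to pick seed?` in the physical operator applies here as well:

```scala
// Illustration of the parse mapping above (hypothetical table `src`):
//   SELECT * FROM src TABLESAMPLE(10 PERCENT)
//     becomes Sample(0.0, 10.0 / 100, withReplacement = false, seed, relation)
//   SELECT * FROM src TABLESAMPLE(BUCKET 1 OUT OF 4)
//     becomes Sample(0.0, 1.0 / 4, withReplacement = false, seed, relation)
object TableSampleFractions extends App {
  println(10.0 / 100)  // 0.1, expected fraction for "10 PERCENT"
  println(1.0 / 4)     // 0.25, expected fraction for "BUCKET 1 OUT OF 4"
}
```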