Fix DataFrame.randomSplit to avoid creating overlapping splits
sameeragarwal committed Jan 7, 2016
commit 27288a30ebe4cea98050803c1dfc55cf6275d162
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala (9 changes: 8 additions & 1 deletion)
@@ -1062,10 +1062,17 @@ class DataFrame private[sql](
   * @since 1.4.0
   */
  def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
    // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
    // constituent partitions each time a split is materialized which could result in
    // overlapping splits. To prevent this, we explicitly sort each input partition to make the
    // ordering deterministic.
    val logicalPlanWithLocalSort =
Contributor:
to make it more concise, just call this "sorted" and then everything fits in one line?

Member Author:
done

      Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
    val sum = weights.sum
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    normalizedCumWeights.sliding(2).map { x =>
-     new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
+     new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed,
+       logicalPlanWithLocalSort))
    }.toArray
  }
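For clarity, here is a small, self-contained sketch of the failure mode the new comment describes: range-based Bernoulli sampling draws one random number per row in iteration order, so if a partition's row order can differ between materializations of two splits, the same row can land in both; a local sort makes the order, and therefore each row's draw, deterministic. This sketch is illustrative only and not part of the patch; the object name, helper, and seed values are made up.

import scala.util.Random

object OverlapSketch {
  // Model one split: walk the rows in the given order and keep each row whose draw from a
  // seeded RNG falls in [lb, ub). The draw depends only on the row's position in the order,
  // mimicking range-based Bernoulli sampling without replacement.
  def split(rows: Seq[Int], lb: Double, ub: Double, seed: Long): Set[Int] = {
    val rng = new Random(seed)
    rows.filter { _ =>
      val d = rng.nextDouble()
      lb <= d && d < ub
    }.toSet
  }

  def main(args: Array[String]): Unit = {
    val rows = (1 to 100).toList
    // Two materializations of the "same" partition, each seeing a different row order.
    val low = split(Random.shuffle(rows), 0.0, 0.5, seed = 7)   // split for [0, 0.5)
    val high = split(Random.shuffle(rows), 0.5, 1.0, seed = 7)  // split for [0.5, 1)
    println(s"overlap with unstable ordering: ${(low intersect high).size}")  // usually > 0

    // With a deterministic (sorted) order, the same seed gives every row the same draw in
    // both materializations, so the disjoint ranges produce disjoint splits.
    val sortedLow = split(rows.sorted, 0.0, 0.5, seed = 7)
    val sortedHigh = split(rows.sorted, 0.5, 1.0, seed = 7)
    println(s"overlap with sorted ordering: ${(sortedLow intersect sortedHigh).size}")  // 0
  }
}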

@@ -129,6 +129,19 @@ class HiveSparkSubmitSuite
    runSparkSubmit(args)
  }

test("SPARK-12662 fix DataFrame.randomSplit to avoid creating overlapping splits") {
Contributor:
Do we need to start a whole new process to test this? I think we can just run randomSplit in the normal DataFrameSuite?

Contributor:
We have not figured out a case that can trigger the problem in local mode.

Contributor:
sc.parallelize(1 to 10).mapPartitions(scala.util.Random.shuffle(_)).collect()

Contributor:
oh, right. We missed it. It is a good one.

Member Author:
That's neat! Converted it into a unit test in DataFrameStatSuite.

    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
    val args = Seq(
      "--class", SPARK_12662.getClass.getName.stripSuffix("$"),
      "--name", "SparkSQLConfTest",
      "--master", "local-cluster[2,1,1024]",
      "--conf", "spark.ui.enabled=false",
      "--conf", "spark.master.rest.enabled=false",
      "--driver-java-options", "-Dderby.system.durability=test",
      unusedJar.toString)
    runSparkSubmit(args)
  }

  // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
  // This is copied from org.apache.spark.deploy.SparkSubmitSuite
  private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -372,3 +385,48 @@ object SPARK_11009 extends QueryTest {
    }
  }
}

/**
 * This object is used to test SPARK-12662: https://issues.apache.org/jira/browse/SPARK-12662.
 * This test ensures that [[org.apache.spark.sql.DataFrame.randomSplit]] does not create
 * overlapping splits even when the underlying dataframe doesn't guarantee a deterministic
 * ordering of rows in each partition.
 */
object SPARK_12662 extends QueryTest {
  import org.apache.spark.sql.functions._

  protected var sqlContext: SQLContext = _

  def main(args: Array[String]): Unit = {
    Utils.configTestLog4j("INFO")

    val sparkContext = new SparkContext(
      new SparkConf()
        .set("spark.sql.shuffle.partitions", "100"))

    val hiveContext = new TestHiveContext(sparkContext)
    sqlContext = hiveContext

    try {
      val n = 600
      val data = sqlContext.range(n).toDF("id").repartition(200, col("id"))
      val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
      assert(splits.length == 3, "wrong number of splits")

      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
        data.sort(col("id")).collect().toList, "incomplete or wrong split")

      for (id <- splits.indices) {
        assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty,
          s"split $id overlaps with split ${(id + 1) % splits.length}")
      }

      val s = splits.map(_.count())
      assert(math.abs(s(0) - 100) < 50) // std = 9.13
      assert(math.abs(s(1) - 200) < 50) // std = 11.55
      assert(math.abs(s(2) - 300) < 50) // std = 12.25
    } finally {
      sparkContext.stop()
    }
  }
}
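Following up on the review thread above, the per-partition shuffle trick makes the bug reproducible in local mode without launching a separate spark-submit process. A hedged sketch of what the resulting unit test in DataFrameStatSuite might look like follows; the test name, exact assertions, and the assumption that the suite exposes sqlContext with its implicit conversions in scope are mine, not the merged code.

test("SPARK-12662: randomSplit should not produce overlapping splits") {
  // Shuffle rows inside each partition so the input order is non-deterministic across
  // materializations, which is exactly the condition that used to cause overlaps.
  // Assumes the suite's implicits (e.g. sqlContext.implicits._) are in scope for toDF.
  val data = sqlContext.sparkContext
    .parallelize(1 to 600, 2)
    .mapPartitions(rows => scala.util.Random.shuffle(rows.toSeq).iterator)
    .toDF("id")

  val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
  assert(splits.length == 3, "wrong number of splits")

  // Together the splits must cover every input row exactly once.
  assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toSeq ==
    data.sort("id").collect().toSeq, "incomplete or wrong split")

  // No two splits may share a row.
  for (i <- splits.indices; j <- splits.indices if i < j) {
    assert(splits(i).intersect(splits(j)).collect().isEmpty, s"splits $i and $j overlap")
  }
}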