Reynold's comments
sameeragarwal committed Jan 7, 2016
commit 633683219022a3d7fa512bdbc3ea7ae6349fc7e1
@@ -62,6 +62,32 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
    }
  }

test("randomSplit on reordered partitions") {
val n = 600
Contributor

Since the tests are run so frequently, I don't think you need to try this many times ... doing it once should be enough.

Member Author

This is just the size of the dataset. We do, however, test for 5 different seeds. Should I just test for 1?

Contributor

yea 1 is fine.

Contributor

also you can just run it twice and make sure the result is deterministic, i.e.

val a = df.randomSplit(...).toSeq.map(_.collect())
val b = df.randomSplit(...).toSeq.map(_.collect())
assert(a == b)

As long as these are Scala collections, I think they will work.
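One subtlety worth noting: collect() returns an Array[Row], and Scala arrays compare by reference, so the equality check is only structural once the collected arrays are converted to Seq (which is what the test below ends up doing). A minimal sketch of the determinism check, with df and the split weights as placeholders:

// Hypothetical sketch: run the same split twice and compare the contents.
// collect() yields Array[Row]; .toSeq makes == compare elements rather than array references.
val a = df.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq)
val b = df.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(a == b, "randomSplit should be deterministic for a fixed seed")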

Member Author

done

Member Author

Sure, but to be fair, this new test does exercise a new code path (that of inserting a sampling operator after a shuffle).
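For context, the removed HiveSparkSubmitSuite object further down exercises exactly that shape, i.e. a repartition (a shuffle) followed by randomSplit:

val data = sqlContext.range(600).toDF("id").repartition(200, col("id"))  // shuffle
val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)          // sampling on top of the shuffle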

Contributor

isn't that the same code path?

Member Author

Once we implement sample pushdown in Catalyst, it shouldn't be: http://research.microsoft.com/pubs/76565/sig99sam.pdf :)

Contributor

What do you mean? The shuffle happens outside of Catalyst, so the optimizer can't push the sample beneath it.

Contributor

To be clear, I'm suggesting removing everything the previous test case already tests, and keeping only:

// Verify that the results are deterministic across multiple runs
val data = sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
val firstRun = splits.toSeq.map(_.collect().toSeq)
val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(firstRun == secondRun)

    // This test ensures that randomSplit does not create overlapping splits even when the
    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
    // rows in each partition.
    val data =
      sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
    for (seed <- 1 to 5) {
      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
      assert(splits.length == 3, "wrong number of splits")

      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
        data.sort($"id").collect().toList, "incomplete or wrong split")

      for (id <- splits.indices) {
        assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty,
          s"split $id overlaps with split ${(id + 1) % splits.length}")
      }

      val s = splits.map(_.count())
      assert(math.abs(s(0) - 100) < 50) // std = 9.13
      assert(math.abs(s(1) - 200) < 50) // std = 11.55
      assert(math.abs(s(2) - 300) < 50) // std = 12.25
    }
  }
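A quick way to sanity-check the standard deviations quoted in the size assertions: for a fixed seed, each of the n rows falls into split i independently with probability w_i / (1 + 2 + 3), so each split size is binomially distributed with mean n * p and std sqrt(n * p * (1 - p)). In plain Scala (no Spark needed; the weights mirror the ones used in the test):

// Split sizes are Binomial(n, p): mean n * p, std sqrt(n * p * (1 - p)).
val n = 600
val weights = Seq(1.0, 2.0, 3.0)
val total = weights.sum
weights.foreach { w =>
  val p = w / total
  println(f"mean = ${n * p}%.0f, std = ${math.sqrt(n * p * (1 - p))}%.2f")
}
// prints: mean = 100, std = 9.13; mean = 200, std = 11.55; mean = 300, std = 12.25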

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
@@ -129,19 +129,6 @@ class HiveSparkSubmitSuite
    runSparkSubmit(args)
  }

  test("SPARK-12662 fix DataFrame.randomSplit to avoid creating overlapping splits") {
    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
    val args = Seq(
      "--class", SPARK_12662.getClass.getName.stripSuffix("$"),
      "--name", "SparkSQLConfTest",
      "--master", "local-cluster[2,1,1024]",
      "--conf", "spark.ui.enabled=false",
      "--conf", "spark.master.rest.enabled=false",
      "--driver-java-options", "-Dderby.system.durability=test",
      unusedJar.toString)
    runSparkSubmit(args)
  }

  // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
  // This is copied from org.apache.spark.deploy.SparkSubmitSuite
  private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -385,48 +372,3 @@ object SPARK_11009 extends QueryTest {
    }
  }
}

/**
 * This object is used to test SPARK-12662: https://issues.apache.org/jira/browse/SPARK-12662.
 * This test ensures that [[org.apache.spark.sql.DataFrame.randomSplit]] does not create overlapping
 * splits even when the underlying dataframe doesn't guarantee a deterministic ordering of rows in
 * each partition.
 */
object SPARK_12662 extends QueryTest {
  import org.apache.spark.sql.functions._

  protected var sqlContext: SQLContext = _

  def main(args: Array[String]): Unit = {
    Utils.configTestLog4j("INFO")

    val sparkContext = new SparkContext(
      new SparkConf()
        .set("spark.sql.shuffle.partitions", "100"))

    val hiveContext = new TestHiveContext(sparkContext)
    sqlContext = hiveContext

    try {
      val n = 600
      val data = sqlContext.range(n).toDF("id").repartition(200, col("id"))
      val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
      assert(splits.length == 3, "wrong number of splits")

      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
        data.sort(col("id")).collect().toList, "incomplete or wrong split")

      for (id <- splits.indices) {
        assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty,
          s"split $id overlaps with split ${(id + 1) % splits.length}")
      }

      val s = splits.map(_.count())
      assert(math.abs(s(0) - 100) < 50) // std = 9.13
      assert(math.abs(s(1) - 200) < 50) // std = 11.55
      assert(math.abs(s(2) - 300) < 50) // std = 12.25
    } finally {
      sparkContext.stop()
    }
  }
}