Fix tests.
viirya committed May 11, 2018
commit 015e2ad739e5ad7fe6d1d1ef3c919661d8ac3d29
7 changes (6 additions & 1 deletion): sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2767,7 +2767,12 @@ class Dataset[T] private[sql](
   * @since 1.6.0
   */
  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
-    plan.executeCollect().head.getLong(0)
+    val collected = plan.executeCollect()
+    if (collected.isEmpty) {
+      0
+    } else {
+      collected.head.getLong(0)
+    }
Member Author (viirya):

spark.range(-10, -9, -20, 1).select("id").count in DataFrameRangeSuite causes an exception here: plan.executeCollect() returns no rows, so plan.executeCollect().head calls next on an empty iterator.
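A minimal repro sketch of that failure (not part of the diff; assumes an active SparkSession named spark):

// The range is empty (negative step with end > start), so executeCollect() returns no
// rows and the old plan.executeCollect().head threw NoSuchElementException.
val empty = spark.range(-10, -9, -20, 1).select("id")
assert(empty.count() == 0L)  // with the guard above this returns 0 instead of throwing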

Contributor:

I think it is caused by returning SinglePartition when there is no data (and therefore no partition). So I think we should fix it there and not here.
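A rough sketch of that alternative fix (hypothetical; it assumes RangeExec exposes numElements, numSlices, and an outputOrdering, and the actual change may look different):

// In RangeExec: only claim a concrete output partitioning when the range is non-empty,
// so an empty range no longer reports SinglePartition.
override def outputPartitioning: Partitioning = {
  if (numElements > 0) {
    if (numSlices == 1) SinglePartition else RangePartitioning(outputOrdering, numSlices)
  } else {
    UnknownPartitioning(numSlices)
  }
}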

Member Author (@viirya, May 11, 2018):

Right, that makes sense. Thanks.

}

/**
@@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {
    def computeChiSquareTest(): Double = {
      val n = 10000
      // Trigger a sort
-      val data = spark.range(0, n, 1, 1).sort('id.desc)
+      // Range has range partitioning in its output now. To have a range shuffle, we
+      // need to run a repartition first.
+      val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc)
Contributor:

Sorry, I am just curious: why is sort('id.desc) not causing a shuffle? Shouldn't the data be ordered by 'id.asc without the sort?

Member Author (viirya):

This test requires a range shuffle. Previously Range had unknown output partitioning/ordering, so a range shuffle was inserted before the sort.

Now Range has an ordered output, so the planner doesn't insert the shuffle we need here.
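For reference, one way to check whether a shuffle was planned (a sketch, not from the PR; assumes a SparkSession named spark):

import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions.desc

// Collect any exchange nodes from the executed plan; if Range's reported partitioning
// already satisfies the sort's required distribution, this comes back empty.
val sorted = spark.range(0, 10000, 1, 1).sort(desc("id"))
val shuffles = sorted.queryExecution.executedPlan.collect { case s: ShuffleExchangeExec => s }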

Contributor:

I'm also confused here. The Range output ordering is 'id.asc, which doesn't match 'id.desc; how can we avoid the shuffle?

Member Author (@viirya, May 17, 2018):

Because Range reports that it has just one partition now?

Contributor:

then can we change the code to spark.range(0, n, 1, 10)?

Member Author (viirya):

This test uses SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION to change the sample size per partition and then checks the chi-square value. It samples just 1 point, so the chi-square value is expected to be high.

If we change it from 1 partition to 10, the chi-square value will change too. Should we do this?
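For context, the relevant part of the test looks roughly like this (a sketch reconstructed from the discussion; the 100/300 thresholds are the ones mentioned further down):

// With default sampling the partitions come out even; with only 1 sample point per
// partition the range boundaries are poor and the chi-square statistic gets large.
assert(computeChiSquareTest() < 100)
withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
  assert(computeChiSquareTest() > 300)
}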

Contributor:

Hmm, isn't spark.range(0, n, 1, 10) almost the same as spark.range(0, n, 1, 1).repartition(10)?

Member Author (viirya):

This is a good point.

This is the query plan and partition sizes for spark.range(0, n, 1, 1).repartition(10).sort('id.desc) when SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION is set to 1:

== Physical Plan ==
*(2) Sort [id#15L DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(id#15L DESC NULLS LAST, 4)
   +- Exchange RoundRobinPartitioning(10)
      +- *(1) Range (0, 10000, step=1, splits=1)

(1666, 3766, 2003, 2565)

And for spark.range(0, n, 1, 10).sort('id.desc):

== Physical Plan ==
*(2) Sort [id#13L DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(id#13L DESC NULLS LAST, 4)
   +- *(1) Range (0, 10000, step=1, splits=10)

(2835, 2469, 2362, 2334)

Because repartition shuffles data with RoundRobinPartitioning, I guess it leads to worse sampling for the range exchange. Without the repartition, Range's output is already range-partitioned, so the sampling produces better range boundaries.

Contributor:

I see, so the 100 and 300 in this test are coupled with the physical execution. I feel the right way to test this is, instead of hardcoding 100 and 300, to have a and b and check whether b > 3 * a or something.
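A sketch of that idea (hypothetical, not part of the PR):

// Compare the two chi-square statistics relatively instead of against fixed thresholds.
val a = computeChiSquareTest()  // default sample size per partition
var b = 0.0
withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
  b = computeChiSquareTest()  // degraded sampling
}
assert(b > 3 * a)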

Member Author (viirya):

With spark.range(0, n, 1, 10).sort('id.desc), there is no 3x linear relation between a and b. As shown above, that also gives an even distribution, and its chi-square value is also under 100.

Here we need to redistribute the data to make the sampling difficult. Previously, a repartition was added automatically before the sort. Now Range has the correct output partitioning info, so the repartition must be added manually.

.selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()

// Compute histogram for the number of records per partition post sort
@@ -55,7 +55,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
-        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+        p.asInstanceOf[WholeStageCodegenExec].child.collect {
Contributor:

Same here, can we change the groupBy instead of the test? (See the sketch after this hunk.)

Member Author (viirya):

ok.

+          case h: HashAggregateExec => h
+        }.nonEmpty).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}
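A sketch of the reviewer's groupBy suggestion above (hypothetical: the original df construction sits outside the shown hunk, and col("id") * 2 merely stands in for a derived grouping key):

import org.apache.spark.sql.functions.col

// Grouping on a derived key keeps the exchange, so a HashAggregateExec stays the direct
// child of a WholeStageCodegenExec and the original assertion could remain unchanged.
val df2 = spark.range(3).groupBy(col("id") * 2).count()
val plan2 = df2.queryExecution.executedPlan
assert(plan2.find(p =>
  p.isInstanceOf[WholeStageCodegenExec] &&
    p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)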

@@ -34,14 +34,13 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

test("debugCodegen") {
val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
Contributor (@cloud-fan, May 17, 2018):

Can we change it to groupBy('id * 2)? We should try our best to keep what we are testing, and keep the shuffle in this query. (See the sketch after this hunk.)

Member Author (viirya):

Ok.

-    assert(res.contains("Subtree 1 / 2"))
-    assert(res.contains("Subtree 2 / 2"))
+    assert(res.contains("Subtree 1 / 1"))
assert(res.contains("Object[]"))
}

test("debugCodegenStringSeq") {
val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
-    assert(res.length == 2)
+    assert(res.length == 1)
Contributor:

ditto

assert(res.forall{ case (subtree, code) =>
subtree.contains("Range") && code.contains("Object[]")})
}
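A sketch of the groupBy('id * 2) suggestion above (hypothetical; uses col("id") * 2 so it doesn't depend on the suite's implicits for the 'id syntax):

import org.apache.spark.sql.functions.col

// A derived grouping key still needs an exchange between the partial and final
// aggregates, so two codegen subtrees are kept and the original assertions can stay.
val res = codegenString(spark.range(10).groupBy(col("id") * 2).count().queryExecution.executedPlan)
assert(res.contains("Subtree 1 / 2"))
assert(res.contains("Subtree 2 / 2"))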