Skip to content

Commit 9a6ac72

Browse files
committed
[SPARK-19601][SQL] Fix CollapseRepartition rule to preserve shuffle-enabled Repartition
### What changes were proposed in this pull request?
Observed by felixcheung in #16739: when users call the shuffle-enabled `repartition` API, they expect the resulting number of partitions to be exactly the number they provided, even if they call the shuffle-disabled `coalesce` later. Currently, the `CollapseRepartition` rule does not consider whether shuffle is enabled or not. Thus, we got the following unexpected result.
```Scala
val df = spark.range(0, 10000, 1, 5)
val df2 = df.repartition(10)
assert(df2.coalesce(13).rdd.getNumPartitions == 5)
assert(df2.coalesce(7).rdd.getNumPartitions == 5)
assert(df2.coalesce(3).rdd.getNumPartitions == 3)
```
This PR fixes the issue by preserving the shuffle-enabled Repartition.
### How was this patch tested?
Added a test case.

Author: Xiao Li <[email protected]>

Closes #16933 from gatorsmile/CollapseRepartition.
1 parent 5f7d835 commit 9a6ac72

File tree

7 files changed

+178
-49
lines changed

7 files changed

+178
-49
lines changed

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2592,8 +2592,8 @@ test_that("coalesce, repartition, numPartitions", {
25922592

25932593
df2 <- repartition(df1, 10)
25942594
expect_equal(getNumPartitions(df2), 10)
2595-
expect_equal(getNumPartitions(coalesce(df2, 13)), 5)
2596-
expect_equal(getNumPartitions(coalesce(df2, 7)), 5)
2595+
expect_equal(getNumPartitions(coalesce(df2, 13)), 10)
2596+
expect_equal(getNumPartitions(coalesce(df2, 7)), 7)
25972597
expect_equal(getNumPartitions(coalesce(df2, 3)), 3)
25982598
})
25992599

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ package object dsl {
370370

371371
def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
372372

373+
def coalesce(num: Integer): LogicalPlan =
374+
Repartition(num, shuffle = false, logicalPlan)
375+
373376
def repartition(num: Integer): LogicalPlan =
374377
Repartition(num, shuffle = true, logicalPlan)
375378

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -564,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] {
564564
}
565565

566566
/**
567-
* Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
568-
* by keeping only the one.
569-
* 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]].
570-
* 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
571-
* 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single
572-
* [[RepartitionByExpression]] with the expression and last number of partition.
567+
* Combines adjacent [[RepartitionOperation]] operators
573568
*/
574569
object CollapseRepartition extends Rule[LogicalPlan] {
575570
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
576-
// Case 1
577-
case Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
578-
Repartition(numPartitions, shuffle, child)
579-
// Case 2
580-
case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) =>
581-
RepartitionByExpression(exprs, child, numPartitions)
582-
// Case 3
583-
case Repartition(numPartitions, _, r: RepartitionByExpression) =>
584-
r.copy(numPartitions = numPartitions)
585-
// Case 3
586-
case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
587-
RepartitionByExpression(exprs, child, numPartitions)
571+
// Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
572+
// 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
573+
// enables the shuffle. Returns the child node if the top node's numPartitions is >= the child's;
574+
// otherwise, keep unchanged.
575+
// 2) In the other cases, returns the top node with the child's child
576+
case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match {
577+
case (false, true) => if (r.numPartitions >= child.numPartitions) child else r
578+
case _ => r.copy(child = child.child)
579+
}
580+
// Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression
581+
// we can remove the child.
582+
case r @ RepartitionByExpression(_, child: RepartitionOperation, _) =>
583+
r.copy(child = child.child)
588584
}
589585
}
590586

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -842,16 +842,24 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
842842
override def output: Seq[Attribute] = child.output
843843
}
844844

845+
/**
846+
* A base interface for [[RepartitionByExpression]] and [[Repartition]]
847+
*/
848+
abstract class RepartitionOperation extends UnaryNode {
849+
def shuffle: Boolean
850+
def numPartitions: Int
851+
override def output: Seq[Attribute] = child.output
852+
}
853+
845854
/**
846855
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
847856
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
848857
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
849858
* of the output requires some specific ordering or distribution of the data.
850859
*/
851860
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
852-
extends UnaryNode {
861+
extends RepartitionOperation {
853862
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
854-
override def output: Seq[Attribute] = child.output
855863
}
856864

857865
/**
@@ -863,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
863871
case class RepartitionByExpression(
864872
partitionExpressions: Seq[Expression],
865873
child: LogicalPlan,
866-
numPartitions: Int) extends UnaryNode {
874+
numPartitions: Int) extends RepartitionOperation {
867875

868876
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
869877

870878
override def maxRows: Option[Long] = child.maxRows
871-
override def output: Seq[Attribute] = child.output
879+
override def shuffle: Boolean = true
872880
}
873881

874882
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala

Lines changed: 137 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {
3232

3333
val testRelation = LocalRelation('a.int, 'b.int)
3434

35+
36+
test("collapse two adjacent coalesces into one") {
37+
// Always respects the top coalesce and removes useless coalesce below coalesce
38+
val query1 = testRelation
39+
.coalesce(10)
40+
.coalesce(20)
41+
val query2 = testRelation
42+
.coalesce(30)
43+
.coalesce(20)
44+
45+
val optimized1 = Optimize.execute(query1.analyze)
46+
val optimized2 = Optimize.execute(query2.analyze)
47+
val correctAnswer = testRelation.coalesce(20).analyze
48+
49+
comparePlans(optimized1, correctAnswer)
50+
comparePlans(optimized2, correctAnswer)
51+
}
52+
3553
test("collapse two adjacent repartitions into one") {
36-
val query = testRelation
54+
// Always respects the top repartition and removes useless repartition below repartition
55+
val query1 = testRelation
56+
.repartition(10)
57+
.repartition(20)
58+
val query2 = testRelation
59+
.repartition(30)
60+
.repartition(20)
61+
62+
val optimized1 = Optimize.execute(query1.analyze)
63+
val optimized2 = Optimize.execute(query2.analyze)
64+
val correctAnswer = testRelation.repartition(20).analyze
65+
66+
comparePlans(optimized1, correctAnswer)
67+
comparePlans(optimized2, correctAnswer)
68+
}
69+
70+
test("coalesce above repartition") {
71+
// Remove useless coalesce above repartition
72+
val query1 = testRelation
3773
.repartition(10)
74+
.coalesce(20)
75+
76+
val optimized1 = Optimize.execute(query1.analyze)
77+
val correctAnswer1 = testRelation.repartition(10).analyze
78+
79+
comparePlans(optimized1, correctAnswer1)
80+
81+
// No change in this case
82+
val query2 = testRelation
83+
.repartition(30)
84+
.coalesce(20)
85+
86+
val optimized2 = Optimize.execute(query2.analyze)
87+
val correctAnswer2 = query2.analyze
88+
89+
comparePlans(optimized2, correctAnswer2)
90+
}
91+
92+
test("repartition above coalesce") {
93+
// Always respects the top repartition and removes useless coalesce below repartition
94+
val query1 = testRelation
95+
.coalesce(10)
96+
.repartition(20)
97+
val query2 = testRelation
98+
.coalesce(30)
3899
.repartition(20)
39100

40-
val optimized = Optimize.execute(query.analyze)
101+
val optimized1 = Optimize.execute(query1.analyze)
102+
val optimized2 = Optimize.execute(query2.analyze)
41103
val correctAnswer = testRelation.repartition(20).analyze
42104

43-
comparePlans(optimized, correctAnswer)
105+
comparePlans(optimized1, correctAnswer)
106+
comparePlans(optimized2, correctAnswer)
44107
}
45108

46-
test("collapse repartition and repartitionBy into one") {
47-
val query = testRelation
109+
test("repartitionBy above repartition") {
110+
// Always respects the top repartitionBy and removes useless repartition
111+
val query1 = testRelation
48112
.repartition(10)
49113
.distribute('a)(20)
114+
val query2 = testRelation
115+
.repartition(30)
116+
.distribute('a)(20)
50117

51-
val optimized = Optimize.execute(query.analyze)
118+
val optimized1 = Optimize.execute(query1.analyze)
119+
val optimized2 = Optimize.execute(query2.analyze)
52120
val correctAnswer = testRelation.distribute('a)(20).analyze
53121

54-
comparePlans(optimized, correctAnswer)
122+
comparePlans(optimized1, correctAnswer)
123+
comparePlans(optimized2, correctAnswer)
55124
}
56125

57-
test("collapse repartitionBy and repartition into one") {
58-
val query = testRelation
126+
test("repartitionBy above coalesce") {
127+
// Always respects the top repartitionBy and removes useless coalesce below repartitionBy
128+
val query1 = testRelation
129+
.coalesce(10)
130+
.distribute('a)(20)
131+
val query2 = testRelation
132+
.coalesce(30)
59133
.distribute('a)(20)
60-
.repartition(10)
61134

62-
val optimized = Optimize.execute(query.analyze)
63-
val correctAnswer = testRelation.distribute('a)(10).analyze
135+
val optimized1 = Optimize.execute(query1.analyze)
136+
val optimized2 = Optimize.execute(query2.analyze)
137+
val correctAnswer = testRelation.distribute('a)(20).analyze
64138

65-
comparePlans(optimized, correctAnswer)
139+
comparePlans(optimized1, correctAnswer)
140+
comparePlans(optimized2, correctAnswer)
141+
}
142+
143+
test("repartition above repartitionBy") {
144+
// Always respects the top repartition and removes useless distribute below repartition
145+
val query1 = testRelation
146+
.distribute('a)(10)
147+
.repartition(20)
148+
val query2 = testRelation
149+
.distribute('a)(30)
150+
.repartition(20)
151+
152+
val optimized1 = Optimize.execute(query1.analyze)
153+
val optimized2 = Optimize.execute(query2.analyze)
154+
val correctAnswer = testRelation.repartition(20).analyze
155+
156+
comparePlans(optimized1, correctAnswer)
157+
comparePlans(optimized2, correctAnswer)
158+
159+
}
160+
161+
test("coalesce above repartitionBy") {
162+
// Remove useless coalesce above repartitionBy
163+
val query1 = testRelation
164+
.distribute('a)(10)
165+
.coalesce(20)
166+
167+
val optimized1 = Optimize.execute(query1.analyze)
168+
val correctAnswer1 = testRelation.distribute('a)(10).analyze
169+
170+
comparePlans(optimized1, correctAnswer1)
171+
172+
// No change in this case
173+
val query2 = testRelation
174+
.distribute('a)(30)
175+
.coalesce(20)
176+
177+
val optimized2 = Optimize.execute(query2.analyze)
178+
val correctAnswer2 = query2.analyze
179+
180+
comparePlans(optimized2, correctAnswer2)
66181
}
67182

68183
test("collapse two adjacent repartitionBys into one") {
69-
val query = testRelation
184+
// Always respects the top repartitionBy
185+
val query1 = testRelation
70186
.distribute('b)(10)
71187
.distribute('a)(20)
188+
val query2 = testRelation
189+
.distribute('b)(30)
190+
.distribute('a)(20)
72191

73-
val optimized = Optimize.execute(query.analyze)
192+
val optimized1 = Optimize.execute(query1.analyze)
193+
val optimized2 = Optimize.execute(query2.analyze)
74194
val correctAnswer = testRelation.distribute('a)(20).analyze
75195

76-
comparePlans(optimized, correctAnswer)
196+
comparePlans(optimized1, correctAnswer)
197+
comparePlans(optimized2, correctAnswer)
77198
}
78199
}

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,11 +2441,11 @@ class Dataset[T] private[sql](
24412441
}
24422442

24432443
/**
2444-
* Returns a new Dataset that has exactly `numPartitions` partitions.
2445-
* Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
2446-
* if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
2447-
* the 100 new partitions will claim 10 of the current partitions. If a larger number of
2448-
* partitions is requested, it will stay at the current number of partitions.
2444+
* Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
2445+
* are requested. If a larger number of partitions is requested, it will stay at the current
2446+
* number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
2447+
* a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
2448+
* be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
24492449
*
24502450
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
24512451
* this may result in your computation taking place on fewer nodes than

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext {
242242
val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
243243
def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
244244
assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
245-
assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
245+
assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
246246
doubleRepartitioned.queryExecution.optimizedPlan match {
247-
case r: Repartition =>
248-
assert(r.numPartitions === 5)
249-
assert(r.shuffle === false)
247+
case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>
248+
assert(numPartitions === 5)
249+
assert(shuffle === false)
250+
assert(shuffleChild === true)
250251
}
251252
}
252253

0 commit comments

Comments
 (0)