Skip to content
Closed

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,157 +18,193 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._

/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._

class DistributionSuite extends SparkFunSuite {

/**
 * Computes `inputPartitioning.gap(requiredDistribution)` and fails the test when the
 * result differs from `expectedGap`.
 *
 * The failure message prints the input partitioning, the required distribution, and the
 * expected and actual gap.
 *
 * @param inputPartitioning    partitioning whose `gap` method is exercised
 * @param requiredDistribution distribution handed to `gap`
 * @param expectedGap          value the computed gap is compared against
 */
protected def checkGap(
    inputPartitioning: Partitioning,
    requiredDistribution: Distribution,
    expectedGap: Gap): Unit = {
  // Evaluate once so the comparison and the failure message report the same value.
  val gap = inputPartitioning.gap(requiredDistribution)
  if (expectedGap != gap) {
    fail(
      s"""
        |== Input Partitioning ==
        |$inputPartitioning
        |== Required Distribution ==
        |$requiredDistribution
        |== Does input partitioning satisfy required distribution? ==
        |Expected $expectedGap got $gap
      """.stripMargin)
  }
}

// Every case below requires UnspecifiedDistribution and expects NoGap from checkGap.
test("UnspecifiedDistribution is the required distribution") {
  checkGap(
    HashPartition(Seq('a, 'b, 'c)),
    UnspecifiedDistribution,
    NoGap)

  checkGap(
    SinglePartition(),
    UnspecifiedDistribution,
    NoGap)

  checkGap(
    RangePartition(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))),
    UnspecifiedDistribution,
    NoGap)

  // Same inputs as the first case above.
  checkGap(
    HashPartition(Seq('a, 'b, 'c)),
    UnspecifiedDistribution,
    NoGap)

  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    UnspecifiedDistribution,
    NoGap)
}

// Exercises checkGap with ClusteredDistribution as the required distribution, against
// unknown, hash, and hash-with-sort input partitionings.
test("ClusteredDistribution is the required distribution") {
  // Ascending sort orders that recur across the cases below.
  val bThenC = Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))
  val bThenA = Seq(SortOrder('b, Ascending), SortOrder('a, Ascending))
  val aThenB = Seq(SortOrder('a, Ascending), SortOrder('b, Ascending))
  val eThenF = Seq(SortOrder('e, Ascending), SortOrder('f, Ascending))

  // --- Unknown input partitioning ---
  checkGap(
    UnknownPartitioning,
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true),
    RepartitionKey(Seq('a, 'b, 'c)))

  // --- HashPartition input, no sort requirement ---
  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(true),
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true),
    RepartitionKey(Seq('a, 'b, 'c)))

  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true),
    NoGap)

  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = false),
    NoGap)

  // --- HashPartition input, sort requirement present ---
  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true, bThenC),
    SortKeyWithinPartition(bThenC))

  checkGap(
    HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = false, bThenC),
    SortKeyWithinPartition(bThenC))

  checkGap(
    HashPartition(Seq('a, 'b)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('a, 'b), nullKeysSensitive = false, bThenA),
    SortKeyWithinPartition(bThenA))

  // Clustering keys are listed in a different order than the partitioning keys.
  checkGap(
    HashPartition(Seq('a, 'b)).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('b, 'a), nullKeysSensitive = false, aThenB),
    RepartitionKeyAndSort(Seq('b, 'a), aThenB))

  // --- HashPartitionWithSort input ---
  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC),
    ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = false),
    NoGap)

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC),
    ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = false, bThenC),
    NoGap)

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC),
    ClusteredDistribution(Seq('c, 'd), nullKeysSensitive = false, eThenF),
    RepartitionKeyAndSort(Seq('c, 'd), eThenF))

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC).withAdditionalNullClusterKeyGenerated(true),
    ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = true),
    RepartitionKey(Seq('b, 'c)))

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = true),
    NoGap)

  checkGap(
    HashPartitionWithSort(Seq('c, 'b), bThenC).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('d, 'e), nullKeysSensitive = true),
    RepartitionKey(Seq('d, 'e)))

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC).withAdditionalNullClusterKeyGenerated(false),
    ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = true, bThenC),
    NoGap)
}

// Exercises checkGap with OrderedDistribution('b asc, 'c asc) as the required distribution.
// NOTE(review): the test name says "RangePartitioning", but every case requires an
// OrderedDistribution -- consider renaming for consistency with the sibling tests.
test("RangePartitioning is the required distribution") {
  // The one ordering required by every case, and its reverse for the last case.
  val bThenC = Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))
  val cThenB = Seq(SortOrder('c, Ascending), SortOrder('b, Ascending))

  // Inputs for which a GlobalOrdering gap is expected.
  checkGap(
    UnknownPartitioning,
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))

  checkGap(
    SinglePartition(),
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))

  checkGap(
    HashPartition(Seq('b, 'c)).withAdditionalNullClusterKeyGenerated(false),
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))

  checkGap(
    HashPartition(Seq('b, 'c)),
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))

  checkGap(
    HashPartitionWithSort(Seq('b, 'c), bThenC).withAdditionalNullClusterKeyGenerated(false),
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))

  // Range partitioning on the same ordering: no gap expected.
  checkGap(
    RangePartition(bThenC),
    OrderedDistribution(bThenC),
    NoGap)

  // Range partitioning on the reversed column order: GlobalOrdering gap expected.
  checkGap(
    RangePartition(cThenB),
    OrderedDistribution(bThenC),
    GlobalOrdering(bThenC))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,17 @@ case class Aggregate(
child: SparkPlan)
extends UnaryNode {

/**
 * Output partitioning of this operator: a `HashPartition` over the grouping expressions
 * when this is not a partial aggregation, otherwise the child's partitioning unchanged.
 */
override def outputPartitioning: Partitioning =
  if (!partial) HashPartition(groupingExpressions) else child.outputPartitioning

/**
 * Distribution required of the single child.
 *
 * A partial aggregation requires `UnspecifiedDistribution`; otherwise the child must be
 * clustered on the grouping expressions with `nullKeysSensitive` set to true.
 *
 * @return a one-element list holding the requirement for the child
 */
override def requiredChildDistribution: List[Distribution] = {
  if (partial) {
    UnspecifiedDistribution :: Nil
  } else {
    // The earlier `groupingExpressions == Nil` / AllTuples branch sat here as a discarded
    // expression: its value was never returned, so it has been removed.
    // NOTE(review): an empty grouping list therefore yields ClusteredDistribution(Nil, ...)
    // rather than AllTuples -- confirm that this is the intended requirement.
    ClusteredDistribution(groupingExpressions, nullKeysSensitive = true) :: Nil
  }
}

Expand Down
Loading