optimize for full outer join
chenghao-intel committed Jul 2, 2015
commit 491a89042c41058653e9b41297c14bb1da856f4d
@@ -50,8 +50,19 @@ case object AllTuples extends Distribution
* [[Expression Expressions]] will be co-located. Based on the context, this
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
* There is also another constraint: a `clustering` value that contains null is
* considered valid only if `nullKeysSensitive` == true.
*
* For example:
* JOIN KEYS: values that contain null are considered invalid, which means
(Review thread on the line above)

Contributor: Does "values" here mean the original values of the table, or the intermediate values of the join? Is a null in the original data of a table also considered invalid?

Contributor Author: It should be the input value (either the original data from the table or an intermediate result, e.g. join output). Whether a null in the original table is valid depends on the semantics: in a join it should also be invalid, but it is valid for GROUP BY. Handling join keys that contain null would be a further optimization for repartitioning.

* the tuples could be in different partitions.
* GROUP BY KEYS: values that contain null are considered valid, which means
* the tuples should be in the same partition.
*/
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    nullKeysSensitive: Boolean) extends Distribution {

require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -157,7 +168,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
-case ClusteredDistribution(requiredClustering) =>
+case ClusteredDistribution(requiredClustering, false) =>
clusteringSet.subsetOf(requiredClustering.toSet)
case _ => false
}
@@ -201,7 +212,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
-case ClusteredDistribution(requiredClustering) =>
+case ClusteredDistribution(requiredClustering, false) =>
clusteringSet.subsetOf(requiredClustering.toSet)
case _ => false
}
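
Under this match, an existing hash (or range) partitioning can only satisfy a null-insensitive requirement; a null-sensitive one always falls through to `case _ => false` and forces a reshuffle. A hedged sketch of the consequence, reusing the simplified model from the previous sketch:

// Simplified mirror of the satisfies check above (illustration only).
def satisfies(clusteringSet: Set[Expression], required: Distribution): Boolean =
  required match {
    case UnspecifiedDistribution => true
    case ClusteredDistribution(requiredClustering, false) =>
      clusteringSet.subsetOf(requiredClustering.toSet)
    case _ => false // includes every nullKeysSensitive = true requirement
  }

satisfies(Set("key"), ClusteredDistribution(Seq("key"), nullKeysSensitive = false)) // true
satisfies(Set("key"), ClusteredDistribution(Seq("key"), nullKeysSensitive = true))  // false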
@@ -51,17 +51,17 @@ class DistributionSuite extends SparkFunSuite {

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
-ClusteredDistribution(Seq('a, 'b, 'c)),
+ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
HashPartitioning(Seq('b, 'c), 10),
-ClusteredDistribution(Seq('a, 'b, 'c)),
+ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
SinglePartition,
-ClusteredDistribution(Seq('a, 'b, 'c)),
+ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
@@ -72,12 +72,12 @@ class DistributionSuite extends SparkFunSuite {
// Cases which need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
-ClusteredDistribution(Seq('b, 'c)),
+ClusteredDistribution(Seq('b, 'c), false),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
-ClusteredDistribution(Seq('d, 'e)),
+ClusteredDistribution(Seq('d, 'e), false),
false)

checkSatisfied(
@@ -128,17 +128,17 @@ class DistributionSuite extends SparkFunSuite {

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-ClusteredDistribution(Seq('a, 'b, 'c)),
+ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-ClusteredDistribution(Seq('c, 'b, 'a)),
+ClusteredDistribution(Seq('c, 'b, 'a), false),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-ClusteredDistribution(Seq('b, 'c, 'a, 'd)),
+ClusteredDistribution(Seq('b, 'c, 'a, 'd), false),
true)

// Cases which need an exchange between two data properties.
@@ -158,12 +158,12 @@ class DistributionSuite extends SparkFunSuite {

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-ClusteredDistribution(Seq('a, 'b)),
+ClusteredDistribution(Seq('a, 'b), false),
false)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-ClusteredDistribution(Seq('c, 'd)),
+ClusteredDistribution(Seq('c, 'd), false),
false)

checkSatisfied(
@@ -51,7 +51,7 @@ case class Aggregate(
if (groupingExpressions == Nil) {
AllTuples :: Nil
} else {
-ClusteredDistribution(groupingExpressions) :: Nil
+ClusteredDistribution(groupingExpressions, true) :: Nil
}
}
}
@@ -291,9 +291,11 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
def addOperatorsIfNecessary(
partitioning: Partitioning,
rowOrdering: Seq[SortOrder],
-child: SparkPlan): SparkPlan = {
+child: SparkPlan,
+alwaysShuffle: Boolean = false): SparkPlan = {
val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
-val needsShuffle = child.outputPartitioning != partitioning
+val needsShuffle = (child.outputPartitioning != partitioning) || alwaysShuffle

val withShuffle = if (needsShuffle) {
Exchange(partitioning, Nil, child)
@@ -326,8 +328,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-case (ClusteredDistribution(clustering), rowOrdering, child) =>
-  addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+case (ClusteredDistribution(clustering, nullKeySensitive), rowOrdering, child) =>
+  addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child, nullKeySensitive)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
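
In effect, a child whose output partitioning already matches the required one is still reshuffled when the requirement is null-key sensitive. A trivial condensation of the decision (the helper name is hypothetical, not from the patch):

// Hypothetical condensation of the needsShuffle logic above.
def needsExchange(partitioningMatches: Boolean, alwaysShuffle: Boolean): Boolean =
  !partitioningMatches || alwaysShuffle

needsExchange(partitioningMatches = true, alwaysShuffle = false) // false: layout reused (joins)
needsExchange(partitioningMatches = true, alwaysShuffle = true)  // true: forced shuffle (GROUP BY)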

@@ -60,7 +60,7 @@ case class GeneratedAggregate(
if (groupingExpressions == Nil) {
AllTuples :: Nil
} else {
-ClusteredDistribution(groupingExpressions) :: Nil
+ClusteredDistribution(groupingExpressions, false) :: Nil
}
}

@@ -45,7 +45,7 @@ case class Window(
// This operator will be very expensive.
AllTuples :: Nil
} else {
-ClusteredDistribution(windowSpec.partitionSpec) :: Nil
+ClusteredDistribution(windowSpec.partitionSpec, true) :: Nil
}

// Since window functions are adding columns to the input rows, the child's outputPartitioning
@@ -245,6 +245,37 @@ case class ExternalSort(
override def outputOrdering: Seq[SortOrder] = sortOrder
}

/**
* :: DeveloperApi ::
* Computes the set of distinct input rows using a HashSet.
* @param partial when true the distinct operation is performed partially, per partition, without
* shuffling the data.
* @param child the input query plan.
*/
@DeveloperApi
case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

override def requiredChildDistribution: Seq[Distribution] =
if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output, true) :: Nil

protected override def doExecute(): RDD[Row] = {
child.execute().mapPartitions { iter =>
val hashSet = new scala.collection.mutable.HashSet[Row]()

var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
if (!hashSet.contains(currentRow)) {
hashSet.add(currentRow.copy())
}
}

hashSet.iterator
}
}
}
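
For context, a sketch of how a planner would typically pair the two modes (an assumed usage pattern, not shown in this diff): the partial pass deduplicates within each partition, shrinking the data before the exchange that the final pass's required distribution triggers.

// child is any input SparkPlan. EnsureRequirements inserts an Exchange between the
// two operators to meet the final pass's ClusteredDistribution(child.output, true).
def planDistinct(child: SparkPlan): SparkPlan =
  Distinct(partial = false,
    Distinct(partial = true, child))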

/**
* :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
@@ -53,7 +53,7 @@ case class HashOuterJoin(
}

override def requiredChildDistribution: Seq[ClusteredDistribution] =
-ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil

override def output: Seq[Attribute] = {
joinType match {
@@ -38,7 +38,7 @@ case class LeftSemiJoinHash(
override val buildSide: BuildSide = BuildRight

override def requiredChildDistribution: Seq[ClusteredDistribution] =
-ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil

override def output: Seq[Attribute] = left.output

@@ -41,7 +41,7 @@ case class ShuffledHashJoin(
override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[ClusteredDistribution] =
-ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil

protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
@@ -42,7 +42,7 @@ case class SortMergeJoin(
override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] =
-ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil

// this is to manually construct an ordering that can be used to compare keys from both sides
private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
@@ -30,6 +30,26 @@ import org.apache.spark.sql.{Row, SQLConf, execution}


class PlannerSuite extends SparkFunSuite {
test("multiway full outer join") {
val planned = testData
.join(testData2, testData("key") === testData2("a"), "outer")
.join(testData3, testData("key") === testData3("a"), "outer")
.queryExecution.executedPlan
val exchanges = planned.collect { case n: Exchange => n }
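// Three exchanges expected (a reading of this assertion, not stated in the test): one
// per input of the first join, plus one for testData3; the second join can reuse the
// first join's hash partitioning because join keys are not null-key sensitive.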

assert(exchanges.size === 3)
(Review thread on the assertion above)

Contributor: Don't these changes also affect

  testData
    .join(testData2, testData("key") === testData2("a"), "outer")
    .join(testData2, testData("a") === testData3("a"), "outer")

?

Contributor Author: This requires some further change, I think; @yhuai should have some idea on this.

}

test("full outer join followed by aggregation") {
val planned = testData
.join(testData2, testData("key") === testData2("a"), "outer") // join key testData('key)
.groupBy(testData("key")).agg(testData("key"), count("a")) // group by key testData('key)
.queryExecution.executedPlan
val exchanges = planned.collect { case n: Exchange => n }
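// Three exchanges expected: two for the join inputs, plus one forced for the
// aggregation, since group-by keys are null-key sensitive and the outer join's
// output partitioning (null keys scattered across partitions) cannot be reused.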

assert(exchanges.size === 3)
}

test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head