Changes from 1 commit
more changes
imback82 committed Sep 5, 2020
commit 7fa91cadd7d8788a6ac855292edb9cfe7b415486
@@ -2698,6 +2698,16 @@ object SQLConf {
.checkValue(_ >= 0, "The value must be non-negative.")
.createWithDefault(8)

val OPTIMIZE_SORT_MERGE_JOIN_WITH_PARTIAL_HASH_DISTRIBUTION =
buildConf("spark.sql.execution.sortMergeJoin.optimizePartialHashDistribution.enabled")
.internal()
.doc("Optimizes sort merge join if both side of join have partial hash distributions - " +
"the output partitioning is HashPartitioning and its expressions are a subset of join " +
"keys on the respective side - by eliminating the shuffle.")
.version("3.1.0")
.booleanConf
.createWithDefault(false)
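A small usage sketch (not part of the patch): the flag above defaults to false, so it has to be enabled per session; the JoinSuite test below toggles it the same way via withSQLConf.

spark.conf.set("spark.sql.execution.sortMergeJoin.optimizePartialHashDistribution.enabled", "true")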

val OPTIMIZE_NULL_AWARE_ANTI_JOIN =
buildConf("spark.sql.optimizeNullAwareAntiJoin")
.internal()
@@ -3357,6 +3367,9 @@ class SQLConf extends Serializable with Logging {
def coalesceBucketsInJoinMaxBucketRatio: Int =
getConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_MAX_BUCKET_RATIO)

def optimizeSortMergeJoinWithPartialHashDistribution: Boolean =
getConf(SQLConf.OPTIMIZE_SORT_MERGE_JOIN_WITH_PARTIAL_HASH_DISTRIBUTION)

def optimizeNullAwareAntiJoin: Boolean =
getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN)

@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.CoalesceBucketsInJoin
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, RemoveShuffleExchangeForSortMergeJoin, ReuseExchange}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, OptimizeSortMergeJoinWithPartialHashDistribution, ReuseExchange}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
@@ -344,7 +344,7 @@ object QueryExecution {
PlanSubqueries(sparkSession),
RemoveRedundantProjects(sparkSession.sessionState.conf),
EnsureRequirements(sparkSession.sessionState.conf),
-      RemoveShuffleExchangeForSortMergeJoin(sparkSession.sessionState.conf),
+      OptimizeSortMergeJoinWithPartialHashDistribution(sparkSession.sessionState.conf),
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
sparkSession.sessionState.columnarRules),
CollapseCodegenStages(sparkSession.sessionState.conf),
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.exchange

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf

/**
 * This rule removes the shuffles for a sort merge join if the following conditions are met:
 *  - The children of both ShuffleExchangeExec nodes have HashPartitioning with the same
 *    number of partitions.
 *  - Each child's output partitioning expressions are a subset of the join keys on the
 *    respective join side.
 *
 * If the above conditions are met, the shuffles can be eliminated for the sort merge join
 * because rows are still sorted before the join logic is applied.
 */
case class OptimizeSortMergeJoinWithPartialHashDistribution(conf: SQLConf) extends Rule[SparkPlan] {
Member commented:
Can't we implement this optimization in EnsureRequirements instead? Any reason to apply this rule after EnsureRequirements inserts shuffles?

Member commented:
Also, could you add fine-grained tests for this rule?

Contributor (author) replied:
To do this inside EnsureRequirements.ensureDistributionAndOrdering, it would require a new Partitioning and Distribution that know both sides of the join, so I didn't go that route. Doing it outside seemed less intrusive. But please let me know if doing this inside EnsureRequirements makes more sense. Thanks.

This is done after EnsureRequirements since reordering the join keys there may already eliminate the shuffles, in which case this rule is not applied.
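To make the two cases concrete, a minimal sketch (not part of the patch), assuming DataFrames t1 and t2 read from bucketed tables like those created in the JoinSuite test below (8 buckets on i1 and i2 respectively):

// Bucket columns (i1, i2) are only a subset of the join keys, so EnsureRequirements
// still inserts a shuffle on (i1, j1) and (i2, j2); this rule then removes both
// shuffles because each side is already hash-partitioned by a subset of its join keys.
val partial = t1.join(t2, t1("i1") === t2("i2") && t1("j1") === t2("j2"))

// If the join keys already line up with the bucketing (possibly after EnsureRequirements
// reorders the keys), no shuffle is inserted in the first place and this rule is a no-op.
val exact = t1.join(t2, t1("i1") === t2("i2"))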

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.optimizeSortMergeJoinWithPartialHashDistribution) {
      return plan
    }

    plan.transformUp {
      case s @ SortMergeJoinExec(_, _, _, _,
          lSort @ SortExec(_, _,
            ExtractShuffleExchangeExecChild(
              lChild,
              lChildOutputPartitioning: HashPartitioning),
            _),
Contributor commented on lines +48 to +51:
nit: why can't we just pattern match ShuffleExchangeExec(_, leftChild, _) here? It seems simpler to me.

          rSort @ SortExec(_, _,
            ExtractShuffleExchangeExecChild(
              rChild,
              rChildOutputPartitioning: HashPartitioning),
            _),
          false) if isPartialHashDistribution(
            s.leftKeys, lChildOutputPartitioning, s.rightKeys, rChildOutputPartitioning) =>
        // Remove ShuffleExchangeExec.
        s.copy(left = lSort.copy(child = lChild), right = rSort.copy(child = rChild))
      case other => other
    }
  }

  /*
   * Returns true if both HashPartitionings have the same number of partitions and
   * their partitioning expressions are a subset of their respective join keys.
   */
  private def isPartialHashDistribution(
      leftKeys: Seq[Expression],
      leftPartitioning: HashPartitioning,
      rightKeys: Seq[Expression],
      rightPartitioning: HashPartitioning): Boolean = {
    val mapping = leftKeyToRightKeyMapping(leftKeys, rightKeys)
    (leftPartitioning.numPartitions == rightPartitioning.numPartitions) &&
      leftPartitioning.expressions.zip(rightPartitioning.expressions)
        .forall {
          case (le, re) => mapping.get(le.canonicalized)
            .map(_.exists(_.semanticEquals(re)))
            .getOrElse(false)
        }
Contributor commented on lines +74 to +81:
Sorry if I miss anything, but I feel this might not be correct. We should make sure that leftPartitioning.expressions and rightPartitioning.expressions have the same size, and the order of the expressions matters, right?

The expression sizes are different, so we should not remove the shuffle:

t1 has 1024 buckets on column (a)
t2 has 1024 buckets on columns (a, b)

SELECT *
FROM t1
JOIN t2
ON t1.a = t2.a AND t1.b = t2.b

The expression sizes are the same, but the order is wrong, so we should not remove the shuffle:

t1 has 1024 buckets on columns (a, b)
t2 has 1024 buckets on columns (b, a)

SELECT *
FROM t1
JOIN t2
ON t1.a = t2.a AND t1.a = t2.b AND t1.b = t2.a AND t1.b = t2.b

Contributor (author) replied:
Thanks. I agree with your concerns for both cases. But, for the first example, only one side will be shuffled, so the rule should not kick in. For the second example, we have t1.a = t2.b AND t1.b = t2.a, which matches the bucket ordering, so this should also be fine.

Contributor replied:
Sorry if I miss anything:

> But, for the first example, only one side will be shuffled, so the rule should not kick in.

If the number of buckets for t1 is less than the number of shuffle partitions, shouldn't it shuffle both sides (in EnsureRequirements)? So the rule kicks in here and removes both shuffles, but we shouldn't remove any shuffle here.

> For the second example, we have t1.a = t2.b AND t1.b = t2.a which matches the bucket ordering, so this should be also fine.

I think it's unsafe if we do not shuffle both sides. HashPartitioning(Seq(a, b)) and HashPartitioning(Seq(b, a)) are not the same thing, e.g. the tuple (a: 1, b: 2) will be assigned to different buckets given the current Murmur3Hash implementation.

imback82 (author) replied on Sep 13, 2020:

> If the number of buckets for t1 is less than the number of shuffle partitions, shouldn't it shuffle both sides (in EnsureRequirements)? So the rule kicks in here and removes both shuffles, but we shouldn't remove any shuffle here.

You are right. Thanks for the catch!

> I think it's unsafe if we do not shuffle both sides. HashPartitioning(Seq(a, b)) and HashPartitioning(Seq(b, a)) are not the same thing, e.g. the tuple (a: 1, b: 2) will be assigned to different buckets given the current Murmur3Hash implementation.

Yes, I understand they produce different hash values. However, the join condition here is t1.a = t2.b AND t1.b = t2.a. On the other hand, this rule will not be applied if the condition were t1.a = t2.a AND t1.b = t2.b. Please let me know if I missed something. Thanks!
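For reference, a minimal sketch of the ordering point above (not part of the patch; it assumes catalyst's Murmur3Hash expression and its default seed of 42):

import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash}

// The row hash folds column hashes left to right, so hashing the tuple as (a, b)
// versus (b, a) generally yields different values; pmod(hash, numPartitions) then
// sends the same row to different partitions under the two partitionings.
val hashAB = Murmur3Hash(Seq(Literal(1), Literal(2)), 42).eval()
val hashBA = Murmur3Hash(Seq(Literal(2), Literal(1)), 42).eval()
println(s"$hashAB vs $hashBA")  // expected to differ for this tuple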

  }

  /*
   * Returns a mapping from each canonicalized left key to all of the right keys
   * it is joined with.
   */
  private def leftKeyToRightKeyMapping(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression]): Map[Expression, Seq[Expression]] = {
    assert(leftKeys.length == rightKeys.length)
    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
    leftKeys.zip(rightKeys).foreach {
      case (leftKey, rightKey) =>
        val key = leftKey.canonicalized
        mapping.get(key) match {
          case Some(v) => mapping.put(key, v :+ rightKey)
          case None => mapping.put(key, Seq(rightKey))
        }
    }
    mapping.toMap
  }
}
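To illustrate the mapping, a minimal sketch (not part of the patch; the attribute names are hypothetical and mirror the one-to-many case exercised in the JoinSuite test below):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Hypothetical attributes standing in for t1.i1, t2.i2 and t2.j2.
val i1 = AttributeReference("i1", IntegerType)()
val i2 = AttributeReference("i2", IntegerType)()
val j2 = AttributeReference("j2", IntegerType)()

// For the join condition t1.i1 = t2.i2 AND t1.i1 = t2.j2 the keys are
//   leftKeys  = Seq(i1, i1)
//   rightKeys = Seq(i2, j2)
// so the mapping collapses the duplicate left key into a single entry:
//   Map(i1.canonicalized -> Seq(i2, j2))
// and isPartialHashDistribution accepts a right-side partitioning expression that
// semantically equals either i2 or j2 as a match for i1.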

/**
 * An extractor that extracts the child of ShuffleExchangeExec and the child's
 * output partitioning.
 */
object ExtractShuffleExchangeExecChild {
  def unapply(plan: SparkPlan): Option[(SparkPlan, Partitioning)] = {
    plan match {
      case s: ShuffleExchangeExec => Some((s.child, s.child.outputPartitioning))
      case _ => None
    }
  }
}

This file was deleted.

55 changes: 55 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1314,4 +1314,59 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
}
}

test("SPARK-XXXXX: Optimize sort merge join with partial hash distribution") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
withTable("t1", "t2") {
val df1 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i1", "j1", "k1")
val df2 = (0 until 100).map(i => (i % 3, i % 17, i.toString)).toDF("i2", "j2", "k2")
df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1")
df2.write.format("parquet").bucketBy(8, "i2").saveAsTable("t2")
val t1 = spark.table("t1")
val t2 = spark.table("t2")

def verify(
f: => DataFrame,
numShufflesWithoutOptimization: Int,
numShufflesWithOptimization: Int): Unit = {
withSQLConf(
SQLConf.OPTIMIZE_SORT_MERGE_JOIN_WITH_PARTIAL_HASH_DISTRIBUTION.key -> "false") {
val dfWithoutOptimization = f
assert(dfWithoutOptimization.queryExecution.executedPlan.collect {
case s: ShuffleExchangeExec => s }.length == numShufflesWithoutOptimization)

withSQLConf(
SQLConf.OPTIMIZE_SORT_MERGE_JOIN_WITH_PARTIAL_HASH_DISTRIBUTION.key -> "true") {
val dfWithOptimization = f
assert(dfWithOptimization.queryExecution.executedPlan.collect {
case s: ShuffleExchangeExec => s }.length == numShufflesWithOptimization)
checkAnswer(dfWithOptimization, dfWithoutOptimization)
}
}
}

def verifyShuffleRemoved(f: => DataFrame): Unit = verify(f, 2, 0)
def verifyShuffleNotRemoved(f: => DataFrame): Unit = verify(f, 2, 2)

// Partial hash distribution by i1 and i2.
verifyShuffleRemoved(t1.join(t2, t1("i1") === t2("i2") && t1("j1") === t2("j2")))
verifyShuffleRemoved(t1.join(t2, t1("i1") === t2("i2") && t1("j1") + 1 === t2("j2")))
verifyShuffleRemoved(
t1.join(t2, t1("i1") === t2("i2") && t1("j1") === t2("j2") && t1("k1") === t2("k2")))
// Partial hash distribution by i1 and i2, but different join key orders.
verifyShuffleRemoved(t1.join(t2, t1("j1") === t2("j2") && t1("i1") === t2("i2")))
verifyShuffleRemoved(
t1.join(t2, t1("j1") === t2("j2") && t1("i1") === t2("i2") && t1("k1") === t2("k2")))
// Many-to-one mapping for join keys (right to left)
verifyShuffleRemoved(t1.join(t2, t1("i1") === t2("i2") && t1("j1") === t2("i2")))
// One-to-many mapping for join keys (left to right)
verifyShuffleRemoved(t1.join(t2, t1("i1") === t2("i2") && t1("i1") === t2("j2")))

// Join keys are not a subset of distribution.
verifyShuffleNotRemoved(t1.join(t2, t1("j1") === t2("j2")))
// The join key (i1 + 1) doesn't match with distribution expression.
verifyShuffleNotRemoved(t1.join(t2, t1("i1") + 1 === t2("i2") && t1("j1") === t2("j2")))
}
}
}
}
@@ -1012,37 +1012,4 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
}
}
}

test("terry") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
withTable("t1", "t2") {
val df1 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i1", "j1", "k1")
val df2 = (0 until 100).map(i => (i % 5, i % 13, i.toString)).toDF("i2", "j2", "k2")
df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1")
df2.write.format("parquet").bucketBy(8, "i2").saveAsTable("t2")
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val joined = t1.join(t2, t1("i1") === t2("i2") && t1("j1") === t2("j2"))
joined.explain(true)
}
}
}

test("terry + 1") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
withTable("t1", "t2") {
val df1 = (0 until 5).map(i => (i % 5, i % 13, i.toString)).toDF("i1", "j1", "k1")
val df2 = (0 until 5).map(i => (i % 3, i % 13, i.toString)).toDF("i2", "j2", "k2")
df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1")
df2.write.format("parquet").bucketBy(8, "i2").saveAsTable("t2")
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val joined = t1.join(t2, t1("i1") === t2("i2") && t1("j1") + 1 === t2("j2"))
joined.explain(true)
df1.show
df2.show
joined.show
}
}
}
}