restore original test
cloud-fan committed Feb 11, 2020
commit 9420d0e63018e2e3c4f7e8d228a358df1fd832b2
@@ -340,12 +340,3 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
case _ => false
}
}

/**
* A test-only partitioning that just outputs the "given key / base" as partition id.
*/
case class PassThroughPartitioning(key: Attribute, base: Int, numPartitions: Int)
extends Partitioning {
assert(key.dataType == IntegerType)
override def satisfies0(required: Distribution): Boolean = true
}
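
The partitioning removed here sends each row to partition `key / base`, which is what let the old test control the post-shuffle data distribution directly. A tiny sketch of that routing rule (illustration only, not the deleted Spark code):

// The routing rule of the removed test-only partitioning:
// a row's partition id is its integer key divided by `base`.
def passThroughPartition(key: Int, base: Int): Int = key / base

// With base = 100, keys 0..99 land in partition 0, keys 100..199 in partition 1, and so on,
// so a test could shape exactly how much data each post-shuffle partition receives.
assert(passThroughPartition(321, 100) == 3)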
@@ -248,12 +248,13 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val startIndices = ShufflePartitionsCoalescer.coalescePartitions(
This is pretty neat.

Array(leftStats, rightStats),
firstPartitionIndex = nonSkewPartitionIndices.head,
lastPartitionIndex = nonSkewPartitionIndices.last,
// `lastPartitionIndex` is exclusive.
lastPartitionIndex = nonSkewPartitionIndices.last + 1,
advisoryTargetSize = conf.targetPostShuffleInputSize)
startIndices.indices.map { i =>
val startIndex = startIndices(i)
val endIndex = if (i == startIndices.length - 1) {
// the `endIndex` is exclusive.
// `endIndex` is exclusive.
nonSkewPartitionIndices.last + 1
} else {
startIndices(i + 1)
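
The change above makes `lastPartitionIndex` exclusive, so the caller now passes `nonSkewPartitionIndices.last + 1`, and each returned start index pairs with the next one (or with that exclusive bound) to form a half-open range. A minimal sketch of that pairing, with a hypothetical `toRanges` helper standing in for the mapping loop above:

// Hypothetical helper: turn the start indices returned by coalescePartitions
// into half-open [start, end) partition ranges, mirroring the loop above.
def toRanges(startIndices: Seq[Int], exclusiveEnd: Int): Seq[(Int, Int)] =
  startIndices.indices.map { i =>
    val start = startIndices(i)
    // The last range ends at the exclusive upper bound; the others end where the next starts.
    val end = if (i == startIndices.length - 1) exclusiveEnd else startIndices(i + 1)
    (start, end)
  }

// e.g. start indices Seq(1, 3) over non-skew partitions 1..4 yield ranges [1, 3) and [3, 5).
assert(toRanges(Seq(1, 3), exclusiveEnd = 5) == Seq((1, 3), (3, 5)))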
@@ -25,9 +25,10 @@ import org.apache.spark.internal.Logging
object ShufflePartitionsCoalescer extends Logging {

/**
* Coalesce the same range of partitions (firstPartitionIndex to lastPartitionIndex, inclusive)
* from multiple shuffles. This method assumes that all the shuffles have the same number of
* partitions, and the partitions of same index will be read together by one task.
* Coalesce the same range of partitions (`firstPartitionIndex` to `lastPartitionIndex`, the
* start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that
* all the shuffles have the same number of partitions, and the partitions of same index will be
* read together by one task.
*
* The strategy used to determine the number of coalesced partitions is described as follows.
* To determine the number of coalesced partitions, we have a target size for a coalesced
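
A rough sketch of the strategy that comment describes (a simplification, not the real ShufflePartitionsCoalescer): walk the partition range, sum each partition's size across all the shuffles, and start a new coalesced partition whenever adding the next partition would push the running total past the advisory target size.

// Simplified sketch: `sizes(i)(p)` is the byte size of partition `p` in shuffle `i`.
// Returns the start indices of the coalesced partitions, like coalescePartitions above.
def coalesceSketch(
    sizes: Array[Array[Long]],
    firstPartitionIndex: Int,
    lastPartitionIndex: Int, // exclusive
    advisoryTargetSize: Long): Seq[Int] = {
  val startIndices = scala.collection.mutable.ArrayBuffer(firstPartitionIndex)
  var coalescedSize = 0L
  for (p <- firstPartitionIndex until lastPartitionIndex) {
    // Partitions with the same index are read together from all shuffles,
    // so the cost of partition `p` is the sum of its size in every shuffle.
    val totalSize = sizes.map(_(p)).sum
    if (p != firstPartitionIndex && coalescedSize + totalSize > advisoryTargetSize) {
      startIndices += p // start a new coalesced partition here
      coalescedSize = totalSize
    } else {
      coalescedSize += totalSize
    }
  }
  startIndices.toSeq
}

// e.g. two shuffles whose partitions total 40, 30, 80 and 20 bytes, with a 100-byte target,
// coalesce into two partitions starting at indices 0 and 2.
assert(coalesceSketch(Array(Array(20L, 15L, 40L, 10L), Array(20L, 15L, 40L, 10L)), 0, 4, 100L) == Seq(0, 2))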
@@ -216,11 +216,6 @@ object ShuffleExchangeExec {
override def numPartitions: Int = 1
override def getPartition(key: Any): Int = 0
}
case PassThroughPartitioning(_, _, n) =>
new Partitioner {
override def numPartitions: Int = n
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -240,10 +235,6 @@
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case p: PassThroughPartitioning =>
val projection = UnsafeProjection.create(
Divide(p.key, Literal(p.base)) :: Nil, outputAttributes)
row => projection(row).getInt(0)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}

@@ -20,24 +20,14 @@ package org.apache.spark.sql.execution.adaptive
import java.io.File
import java.net.URI

import scala.util.Random

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.PassThroughPartitioning
import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ReusedSubqueryExec, SparkPlan}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class AdaptiveQueryExecSuite
@@ -613,23 +603,23 @@ class AdaptiveQueryExecSuite
}
}

// TODO: we need a way to customize data distribution after shuffle, to improve test coverage
// of this case.
test("SPARK-29544: adaptive skew join with different join types") {
// Unfortunately, we can't remove the injected extension. The `SkewJoinTestStrategy` is
// harmless and only affects this test suite.
spark.extensions.injectPlannerStrategy(_ => SkewJoinTestStrategy)
def createRelation(partitionRowCount: Int*): DataFrame = {
val output = new StructType().add("key", "int").toAttributes
Dataset.ofRows(spark, SkewJoinTestSource(output, partitionRowCount))
}

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100",
SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR.key -> "2") {
withTempView("t1", "t2") {
createRelation(3100, 100, 3200, 300, 3300, 400, 500).createTempView("t1")
createRelation(3400, 200, 300, 2900, 3200, 100, 600).createTempView("t2")
SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") {
withTempView("skewData1", "skewData2") {
spark
.range(0, 1000, 1, 10)
.selectExpr("id % 2 as key1", "id as value1")
.createOrReplaceTempView("skewData1")
spark
.range(0, 1000, 1, 10)
.selectExpr("id % 1 as key2", "id as value2")
.createOrReplaceTempView("skewData2")

def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = {
assert(joins.size == 1 && joins.head.isSkewJoin)
@@ -643,55 +633,45 @@

// skewed inner join optimization
val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM t1 join t2 ON t1.key = t2.key")
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
// left stats: [3496, 0, 0, 0, 4014]
// right stats:[6292, 0, 0, 0, 0]
// Partition 0: both left and right sides are skewed, and divide into 5 splits, so
// 5 x 5 sub-partitions.
// Partition 1: not skewed, just 1 partition.
// Partition 2: only left side is skewed, and divide into 5 splits, so
// 5 sub-partitions.
// Partition 3: only right side is skewed, and divide into 5 splits, so
// Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
// Partition 4: only left side is skewed, and divide into 5 splits, so
// 5 sub-partitions.
// Partition 4: both left and right sides are skewed, and divide into 5 splits, so
// 5 x 5 sub-partitions.
// Partition 5, 6: not skewed, and coalesced into 1 partition.
// So total (25 + 1 + 5 + 5 + 25 + 1) partitions.
// So total (25 + 1 + 5) partitions.
val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
checkSkewJoin(innerSmj, 25 + 1 + 5 + 5 + 25 + 1)
checkSkewJoin(innerSmj, 25 + 1 + 5)

// skewed left outer join optimization
val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM t1 left outer join t2 ON t1.key = t2.key")
"SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
// left stats: [3496, 0, 0, 0, 4014]
// right stats:[6292, 0, 0, 0, 0]
// Partition 0: both left and right sides are skewed, but left join can't split right side,
// so only left side is divided into 5 splits, and thus 5 sub-partitions.
// Partition 1: not skewed, just 1 partition.
// Partition 2: only left side is skewed, and divide into 5 splits, so
// Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
// Partition 4: only left side is skewed, and divide into 5 splits, so
// 5 sub-partitions.
// Partition 3: only right side is skewed, but left join can't split right side, so just
// 1 partition.
// Partition 4: both left and right sides are skewed, but left join can't split right side,
// so only left side is divided into 5 splits, and thus 5 sub-partitions.
// Partition 5, 6: not skewed, and coalesced into 1 partition.
// So total (5 + 1 + 5 + 1 + 5 + 1) partitions.
// So total (5 + 1 + 5) partitions.
val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
checkSkewJoin(leftSmj, 5 + 1 + 5 + 1 + 5 + 1)
checkSkewJoin(leftSmj, 5 + 1 + 5)

// skewed right outer join optimization
val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM t1 right outer join t2 ON t1.key = t2.key")
"SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
// left stats: [3496, 0, 0, 0, 4014]
// right stats:[6292, 0, 0, 0, 0]
// Partition 0: both left and right sides are skewed, but right join can't split left side,
// so only right side is divided into 5 splits, and thus 5 sub-partitions.
// Partition 1: not skewed, just 1 partition.
// Partition 2: only left side is skewed, but right join can't split left side, so just
// Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
// Partition 4: only left side is skewed, but right join can't split left side, so just
// 1 partition.
// Partition 1 and 2 get coalesced.
// Partition 3: only right side is skewed, and divide into 5 splits, so
// 5 sub-partitions.
// Partition 4: both left and right sides are skewed, but right join can't split left side,
// so only right side is divided into 5 splits, and thus 5 sub-partitions.
// Partition 5, 6: not skewed, and coalesced into 1 partition.
// So total (5 + 1 + 5 + 5 + 1) partitions.
// So total (5 + 1 + 1) partitions.
val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
checkSkewJoin(rightSmj, 5 + 1 + 5 + 5 + 1)
checkSkewJoin(rightSmj, 5 + 1 + 1)
}
}
}
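
The expected counts in the comments above all follow one rule: a skewed partition contributes leftSplits * rightSplits sub-partitions, a side that the join type cannot split counts as a single split, and the non-skewed partitions are coalesced into one. A small sketch of that arithmetic, using the 5-split figures from the comments (a hypothetical helper, not part of the test):

// Hypothetical restatement of the counting used in the comments above.
// `leftSplits`/`rightSplits` give the split count per shuffle partition (1 = not skewed).
def expectedPartitions(
    leftSplits: Seq[Int],
    rightSplits: Seq[Int],
    canSplitLeft: Boolean,
    canSplitRight: Boolean): Int = {
  val perPartition = leftSplits.zip(rightSplits).map { case (l, r) =>
    (if (canSplitLeft) l else 1) * (if (canSplitRight) r else 1)
  }
  // All the non-skewed (1 x 1) partitions are coalesced into a single partition.
  perPartition.filter(_ > 1).sum + 1
}

// Partition 0 is skewed on both sides and partition 4 only on the left (5 splits each),
// matching the left/right stats quoted in the comments.
val left = Seq(5, 1, 1, 1, 5)
val right = Seq(5, 1, 1, 1, 1)
assert(expectedPartitions(left, right, canSplitLeft = true, canSplitRight = true) == 25 + 1 + 5)  // inner
assert(expectedPartitions(left, right, canSplitLeft = true, canSplitRight = false) == 5 + 1 + 5)  // left outer
assert(expectedPartitions(left, right, canSplitLeft = false, canSplitRight = true) == 5 + 1 + 1)  // right outer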
@@ -735,52 +715,3 @@ class AdaptiveQueryExecSuite
}
}

case class SkewJoinTestSource(output: Seq[Attribute], partitionRowCount: Seq[Int])
extends LeafNode {
override def computeStats(): Statistics = Statistics(Long.MaxValue)
}

case class SkewJoinTestSourceExec(output: Seq[Attribute], partitionRowCount: Seq[Int])
extends LeafExecNode {

override protected def doExecute(): RDD[InternalRow] = {
val sum = partitionRowCount.sum
sparkContext.makeRDD(Seq.empty[Byte], 10).mapPartitions { _ =>
val proj = UnsafeProjection.create(output, output)
val rand = new Random(TaskContext.getPartitionId())

// Each RDD partition generates different partition IDs, but overall the partition ID
// distribution respects the ratio specified in `partitionRowCount`.
Seq.fill(sum / 10) {
val value = rand.nextInt(sum)
var partId = -1
var currentSum = 0
var i = 0
while (partId == -1 && i < partitionRowCount.length) {
currentSum += partitionRowCount(i)
if (value < currentSum) partId = i
i += 1
}
// Increase the partition ID diversity to avoid the join outputting too many results.
InternalRow(rand.nextInt(50) + partId * 100)
}.iterator.map(proj)
}
}
}

object SkewJoinTestStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ScanOperation(projectList, filters, s: SkewJoinTestSource) =>
assert(projectList == s.output)
val sourceExec = SkewJoinTestSourceExec(s.output, s.partitionRowCount)
val withFilter = if (filters.isEmpty) {
sourceExec
} else {
FilterExec(filters.reduce(And), sourceExec)
}
ShuffleExchangeExec(
PassThroughPartitioning(s.output.head, 100, s.partitionRowCount.length),
withFilter) :: Nil
case _ => Nil
}
}