@@ -165,11 +165,6 @@ sealed trait Partitioning {
* produced by `A` could have also been produced by `B`.
*/
def guarantees(other: Partitioning): Boolean = this == other

def withNumPartitions(newNumPartitions: Int): Partitioning = {
throw new IllegalStateException(
s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}")
}
}

object Partitioning {
@@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def withNumPartitions(newNumPartitions: Int): HashPartitioning = {
HashPartitioning(expressions, newNumPartitions)
}
}

/**
@@ -242,7 +242,7 @@ case class Exchange(
// update the number of post-shuffle partitions.
specifiedPartitionStartIndices.foreach { indices =>
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = newPartitioning.withNumPartitions(indices.length)
newPartitioning = UnknownPartitioning(indices.length)
}
new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
}
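Not part of the diff: a minimal sketch of why the coordinated output is reported as UnknownPartitioning rather than the original hash partitioning. The specified partition start indices coalesce contiguous map-output partitions, so the post-shuffle output has a different number (and layout) of partitions than the HashPartitioning promised. postShuffleRanges below is a hypothetical helper, not Spark code.

  // Hypothetical helper illustrating the index arithmetic: given the number of pre-shuffle
  // (map output) partitions and the start indices chosen by the coordinator, compute which
  // pre-shuffle partitions each post-shuffle partition reads.
  def postShuffleRanges(numPreShufflePartitions: Int, startIndices: Array[Int]): Seq[Range] =
    startIndices.indices.map { i =>
      val start = startIndices(i)
      val end =
        if (i + 1 < startIndices.length) startIndices(i + 1) else numPreShufflePartitions
      start until end
    }

  // 9 map output partitions coalesced into 3 post-shuffle partitions covering
  // [0, 3), [3, 6), and [6, 9). A HashPartitioning with numPartitions = 9 no longer
  // describes this 3-partition output, hence UnknownPartitioning(3).
  postShuffleRanges(9, Array(0, 3, 6)).foreach(println)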
@@ -262,7 +262,7 @@ case class Exchange(

object Exchange {
def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
Exchange(newPartitioning, child, None: Option[ExchangeCoordinator])
Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}
}

@@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child.outputPartitioning match {
case hash: HashPartitioning => true
case collection: PartitioningCollection =>
collection.partitionings.exists(_.isInstanceOf[HashPartitioning])
collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
case _ => false
}
}
@@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
// First check if the existing partitions of the children all match. This means they are
// partitioned by the same partitioning into the same number of partitions. In that case,
// don't try to make them match `defaultPartitions`, just use the existing partitioning.
// TODO: this should be a cost based decision. For example, a big relation should probably
// maintain its existing number of partitions and smaller partitions should be shuffled.
// defaultPartitions is arbitrary.
val numPartitions = children.head.outputPartitioning.numPartitions
val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
case (child, distribution) => {
child.outputPartitioning.guarantees(
createPartitioning(distribution, numPartitions))
createPartitioning(distribution, maxChildrenNumPartitions))
}
}

children = if (useExistingPartitioning) {
// We do not need to shuffle any child's output.
children
} else {
// We need to shuffle at least one child's output.
// Now, we will determine the number of partitions that will be used by created
// partitioning schemes.
val numPartitions = {
// Let's see if we need to shuffle all children's outputs when we use
// maxChildrenNumPartitions.
val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
case (child, distribution) => {
!child.outputPartitioning.guarantees(
createPartitioning(distribution, maxChildrenNumPartitions))
}
}
// If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
// number of partitions. Otherwise, we use maxChildrenNumPartitions.
if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
}

children.zip(requiredChildDistributions).map {
case (child, distribution) => {
val targetPartitioning =
createPartitioning(distribution, defaultNumPreShufflePartitions)
createPartitioning(distribution, numPartitions)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
Exchange(targetPartitioning, child)
child match {
// If child is an exchange, we replace it with
// a new one having targetPartitioning.
case Exchange(_, c, _) => Exchange(targetPartitioning, c)
case _ => Exchange(targetPartitioning, child)
}
Contributor Author:
@nongli Can you take a look here? If one side of the join is already shuffled, I am trying to avoid shuffling that side.

Contributor:
Looks good to me. Can you update/remove the TODO on line 419?

}
}
}
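Not part of the diff: a condensed sketch of the decision encoded in this hunk, using a hypothetical Child model rather than Spark's SparkPlan and Partitioning classes. It shows the rule the comments describe: keep maxChildrenNumPartitions unless every child would have to be shuffled at that number anyway, in which case fall back to defaultNumPreShufflePartitions.

  // Simplified, hypothetical model of the decision above (not Spark's actual classes).
  // guaranteesAt answers: "does this child's existing partitioning satisfy the required
  // distribution at the given partition count?"
  final case class Child(numPartitions: Int, guaranteesAt: Int => Boolean)

  def choosePartitionNumber(children: Seq[Child], defaultNumPreShufflePartitions: Int): Int = {
    val maxChildrenNumPartitions = children.map(_.numPartitions).max
    // If no child's existing partitioning survives at maxChildrenNumPartitions, every child
    // gets shuffled anyway, so use the default pre-shuffle partition number instead.
    val shufflesAllChildren = children.forall(c => !c.guaranteesAt(maxChildrenNumPartitions))
    if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
  }

  // Two cached children with 6 and 3 partitions; the 6-partition side already satisfies the
  // required distribution at 6 partitions, so 6 is kept and only the 3-partition side shuffles.
  val children = Seq(Child(6, n => n == 6), Child(3, n => n == 3))
  assert(choosePartitionNumber(children, defaultNumPreShufflePartitions = 200) == 6)

The separate case Exchange(_, c, _) match above then re-targets an already-planned Exchange's child c instead of wrapping it in a second Exchange, which is what the review comment above refers to.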
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import java.util.{Map => JMap, HashMap => JHashMap}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ArrayBuffer

@@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator(
* Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be
* called in the `doPrepare` method of an [[Exchange]] operator.
*/
@GuardedBy("this")
def registerExchange(exchange: Exchange): Unit = synchronized {
exchanges += exchange
}
@@ -109,7 +111,7 @@
*/
private[sql] def estimatePartitionStartIndices(
mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
// If we have mapOutputStatistics.length <= numExchange, it is because we do not submit
// If we have mapOutputStatistics.length < numExchange, it is because we do not submit
// a stage when the number of partitions of this dependency is 0.
assert(mapOutputStatistics.length <= numExchanges)

@@ -121,6 +123,8 @@
val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
// The max at here is to make sure that when we have an empty table, we
// only have a single post-shuffle partition.
// There is no particular reason that we pick 16. We just need a number to
// prevent maxPostShuffleInputSize from being set to 0.
val maxPostShuffleInputSize =
math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16)
math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
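Not part of the diff: a small worked example of the target-size computation shown above, with made-up numbers. It assumes numPartitions here plays the role of a desired minimum number of post-shuffle partitions and that advisoryTargetPostShuffleInputSize is the configured advisory size; the result is bound to a local targetPostShuffleInputSize only for the assertion.

  // Hypothetical numbers: 640 MB of total shuffle input, a desired minimum of 5 post-shuffle
  // partitions, and a 64 MB advisory target size per post-shuffle partition.
  val totalPostShuffleInputSize = 640L * 1024 * 1024
  val numPartitions = 5
  val advisoryTargetPostShuffleInputSize = 64L * 1024 * 1024

  // 640 MB / 5 = 128 MB per partition; the max(..., 16) only matters for an empty input,
  // where it keeps the target at 16 bytes instead of 0.
  val maxPostShuffleInputSize =
    math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16)
  // The advisory 64 MB cap wins here, so partitions are packed to roughly 64 MB each.
  val targetPostShuffleInputSize =
    math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
  assert(targetPostShuffleInputSize == 64L * 1024 * 1024)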
@@ -135,6 +139,12 @@
// Make sure we do get the same number of pre-shuffle partitions for those stages.
val distinctNumPreShufflePartitions =
mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct
// The reason that we are expecting a single value of the number of pre-shuffle partitions
// is that when we add Exchanges, we set the number of pre-shuffle partitions
// (i.e. map output partitions) using a static setting, which is the value of
// spark.sql.shuffle.partitions. Even if two input RDDs have different
// numbers of partitions, they will have the same number of pre-shuffle partitions
// (i.e. map output partitions).
assert(
distinctNumPreShufflePartitions.length == 1,
"There should be only one distinct value of the number pre-shuffle partitions " +
@@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator(
partitionStartIndices.toArray
}

@GuardedBy("this")
private def doEstimationIfNecessary(): Unit = synchronized {
// It is unlikely that this method will be called from multiple threads
// (when multiple threads trigger the execution of THIS physical)
@@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator(

// Wait for the finishes of those submitted map stages.
val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length)
i = 0
while (i < submittedStageFutures.length) {
var j = 0
while (j < submittedStageFutures.length) {
// This call is a blocking call. If the stage has not finished, we will wait at here.
mapOutputStatistics(i) = submittedStageFutures(i).get()
i += 1
mapOutputStatistics(j) = submittedStageFutures(j).get()
j += 1
}

// Now, we estimate partitionStartIndices. partitionStartIndices.length will be the
@@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator(
Some(estimatePartitionStartIndices(mapOutputStatistics))
}

i = 0
while (i < numExchanges) {
val exchange = exchanges(i)
var k = 0
while (k < numExchanges) {
val exchange = exchanges(k)
val rdd =
exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices)
exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices)
newPostShuffleRDDs.put(exchange, rdd)

i += 1
k += 1
}

// Finally, we set postShuffleRDDs and estimated.
150 changes: 116 additions & 34 deletions sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext}
import org.apache.spark.storage.{StorageLevel, RDDBlockId}

private case class BigData(s: String)

class CachedTableSuite extends QueryTest with SharedSQLContext {
class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext {
import testImplicits._

def rddIdOf(tableName: String): Int = {
@@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
sqlContext.uncacheTable("orderedTable")
sqlContext.dropTempTable("orderedTable")

// Set up two tables distributed in the same way. Try this with the data distributed into
// different number of partitions.
for (numPartitions <- 1 until 10 by 4) {
testData.repartition(numPartitions, $"key").registerTempTable("t1")
testData2.repartition(numPartitions, $"a").registerTempTable("t2")
withTempTable("t1", "t2") {
testData.repartition(numPartitions, $"key").registerTempTable("t1")
testData2.repartition(numPartitions, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

// Joining them should result in no exchanges.
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))

// Grouping on the partition key should result in no exchanges
verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
sql("SELECT count(*) FROM testData GROUP BY key"))

sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
}
}

// Distribute the tables into non-matching number of partitions. Need to shuffle one side.
withTempTable("t1", "t2") {
testData.repartition(6, $"key").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

// Joining them should result in no exchanges.
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
}

// Grouping on the partition key should result in no exchanges
verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
sql("SELECT count(*) FROM testData GROUP BY key"))
// One side of join is not partitioned in the desired way. Need to shuffle one side.
withTempTable("t1", "t2") {
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(6, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")
}

// Distribute the tables into non-matching number of partitions. Need to shuffle.
testData.repartition(6, $"key").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")
withTempTable("t1", "t2") {
Contributor:
What does this do if we do repartition(c1, c2).groupBy(c2, c1)?

Contributor Author:
In this case, we will not add an Exchange. But let me double-check it and add a test case.

Contributor:
Ok, and the corresponding join case would be interesting as well.

Contributor Author:
Oh, right now the join will not work because a different ordering of columns makes us generate different hash codes. But ideally we should avoid shuffling for the join in this case.
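Not part of the diff: a toy illustration of the point in the comment above, using Scala's order-sensitive Seq hash code as a stand-in for Spark's row hashing. Repartitioning two tables on the same columns in different orders can place matching rows in different partitions, so the join still needs an Exchange.

  // Order-sensitive hashing over a sequence of column values: the same logical row lands in
  // different partitions depending on the order in which its columns are hashed.
  def partitionOf(columns: Seq[Any], numPartitions: Int): Int = {
    val h = columns.hashCode()   // Seq hash codes depend on element order
    ((h % numPartitions) + numPartitions) % numPartitions
  }

  val (key, value) = (1, "a")
  val p1 = partitionOf(Seq(key, value), 6)   // like one table repartitioned on (key, value)
  val p2 = partitionOf(Seq(value, key), 6)   // like the other repartitioned on (value, key)
  // p1 and p2 generally differ, so rows that would join are not co-located and an
  // Exchange is still needed even though both sides cluster on the same set of columns.
  println(s"p1 = $p1, p2 = $p2")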

testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(12, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
}

// One side of join is not partitioned in the desired way. Need to shuffle.
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(6, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")
// One side of the join is not partitioned in the desired way. Since the number of partitions of
// the side that is already partitioned is smaller than that of the side that is not partitioned,
// we shuffle both sides.
withTempTable("t1", "t2") {
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 2)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
}

// repartition's column ordering is different from group by column ordering.
// But they use the same set of columns.
withTempTable("t1") {
testData.repartition(6, $"value", $"key").registerTempTable("t1")
sqlContext.cacheTable("t1")

val query = sql("SELECT value, key from t1 group by key, value")
verifyNumExchanges(query, 0)
checkAnswer(
query,
testData.distinct().select($"value", $"key"))
sqlContext.uncacheTable("t1")
}

// repartition's column ordering is different from join condition's column ordering.
// We will still shuffle because hashcodes of a row depend on the column ordering.
// If we do not shuffle, we may actually partition the two tables in two totally different ways.
// See PartitioningSuite for more details.
withTempTable("t1", "t2") {
val df1 = testData
df1.repartition(6, $"value", $"key").registerTempTable("t1")
val df2 = testData2.select($"a", $"b".cast("string"))
df2.repartition(6, $"a", $"b").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")

val query =
sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b")
verifyNumExchanges(query, 1)
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
checkAnswer(
query,
df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
}
Contributor Author:
@nongli I added two test cases.

}
}
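The refactored tests above lean on SQLTestUtils.withTempTable to guarantee cleanup between cases. Below is a minimal sketch of that pattern, assuming only that the helper drops the named temporary tables after the body runs; the real SQLTestUtils implementation may differ in details, and testData, testData2, sqlContext, and $ are taken from the suite.

  // Sketch of the withTempTable pattern: run the test body, then always drop the named
  // temporary tables so one test cannot leak cached or registered state into the next.
  def withTempTable(tableNames: String*)(body: => Unit): Unit = {
    try body
    finally tableNames.foreach(sqlContext.dropTempTable)
  }

  // Usage mirroring the tests above.
  withTempTable("t1", "t2") {
    testData.repartition(6, $"key").registerTempTable("t1")
    testData2.repartition(3, $"a").registerTempTable("t2")
    sqlContext.cacheTable("t1")
    sqlContext.cacheTable("t2")
    // ... run queries and verifyNumExchanges assertions here ...
    sqlContext.uncacheTable("t1")
    sqlContext.uncacheTable("t2")
  }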