diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 42dead7c2842..b6fd2bccc99f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,10 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
-import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
  * Specifies how tuples that share common expressions will be distributed when a query is executed
@@ -39,25 +36,24 @@ sealed trait Distribution
  */
 case object UnspecifiedDistribution extends Distribution
 
-/**
- * Represents a distribution that only has a single partition and all tuples of the dataset
- * are co-located.
- */
-case object AllTuples extends Distribution
-
 /**
  * Represents data where tuples that share the same values for the `clustering`
  * [[Expression Expressions]] will be co-located. Based on the context, this
  * can mean such tuples are either co-located in the same partition or they will be contiguous
  * within a single partition.
+ * There is one additional constraint: a `clustering` value that contains null is only
+ * treated as a valid clustering value if `nullKeysSensitive` is true.
+ *
+ * For example:
+ *   JOIN KEYS: values containing null are considered invalid, which means
+ *   such tuples may end up in different partitions.
+ *   GROUP BY KEYS: values containing null are considered valid, which means
+ *   such tuples must end up in the same partition.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
-  require(
-    clustering != Nil,
-    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
-    "An AllTuples should be used to represent a distribution that only has " +
-    "a single partition.")
-}
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    nullKeysSensitive: Boolean,
+    sortKeys: Seq[SortOrder] = Nil) extends Distribution
 
 /**
  * Represents data where tuples have been ordered according to the `ordering`
@@ -76,144 +72,267 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
-sealed trait Partitioning {
-  /** Returns the number of partitions that the data is split across */
-  val numPartitions: Int
+/**
+ * From a child operator's point of view, a `Gap` describes what must be done to its output so
+ * that the data distribution required by its parent operator is satisfied.
+ *
+ * NOTE: This trait and its subclasses are not used by the physical operators directly.
+ */
+private[sql] sealed trait Gap
 
-  /**
-   * Returns true iff the guarantees made by this [[Partitioning]] are sufficient
-   * to satisfy the partitioning scheme mandated by the `required` [[Distribution]],
-   * i.e. the current dataset does not need to be re-partitioned for the `required`
-   * Distribution (it is possible that tuples within a partition need to be reorganized).
-   */
-  def satisfies(required: Distribution): Boolean
+/**
+ * Nothing needs to be done to satisfy the required data distribution.
+ */
+private[sql] case object NoGap extends Gap
 
-  /**
-   * Returns true iff all distribution guarantees made by this partitioning can also be made
-   * for the `other` specified partitioning.
-   * For example, two [[HashPartitioning HashPartitioning]]s are
-   * only compatible if the `numPartitions` of them is the same.
-   */
-  def compatibleWith(other: Partitioning): Boolean
+/**
+ * The data within each partition needs to be sorted by the given keys.
+ * @param sortKeys the sorting keys
+ */
+private[sql] case class SortKeyWithinPartition(sortKeys: Seq[SortOrder]) extends Gap
 
-  /** Returns the expressions that are used to key the partitioning. */
-  def keyExpressions: Seq[Expression]
-}
+/**
+ * The data needs a global (cross-partition) sort according to the specified sorting keys.
+ * @param ordering the sorting keys
+ */
+private[sql] case class GlobalOrdering(ordering: Seq[SortOrder]) extends Gap
 
-case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case _ => false
-  }
+/**
+ * The data needs to be repartitioned by the new clustering expressions. It is possible that
+ * only a single partition is needed; in that case the clustering expressions are ignored.
+ * @param clustering the clustering keys
+ */
+private[sql] case class RepartitionKey(
+    clustering: Seq[Expression]) extends Gap
 
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case UnknownPartitioning(_) => true
-    case _ => false
-  }
+/**
+ * The data needs to be repartitioned by the new clustering expressions, and the rows within
+ * each partition also need to be sorted by those expressions.
+ * Note: the clustering expressions should be the same as the sort keys.
+ * @param clustering the clustering expressions
+ * @param sortKeys the sorting keys; they should match the clustering expressions, but also
+ *                 carry a sorting direction.
+ */
+private[sql] case class RepartitionKeyAndSort(
+    clustering: Seq[Expression],
+    sortKeys: Seq[SortOrder]) extends Gap
 
-  override def keyExpressions: Seq[Expression] = Nil
-}
 
-case object SinglePartition extends Partitioning {
-  val numPartitions = 1
+/**
+ * Represents the output data distribution of a physical operator.
+ *
+ * @param numPartitions the number of partitions that the data is split across, if known
+ * @param clusterKeys the expressions that are used to key the partitioning
+ * @param sortKeys the expressions that are used to sort the data
+ * @param globalOrdered whether `sortKeys` describe a global sort or a per-partition sort
+ * @param additionalNullClusterKeyGenerated whether this operator may generate new null keys
+ */
+sealed case class Partitioning(
+    /** the number of partitions that the data is split across */
+    numPartitions: Option[Int] = None,
+
+    /** the expressions that are used to key the partitioning. */
+    clusterKeys: Seq[Expression] = Nil,
+
+    /** the expressions that are used to sort the data. */
+    sortKeys: Seq[SortOrder] = Nil,
+
+    /** works with `sortKeys`: whether the sort is global or only within each partition. */
+    globalOrdered: Boolean = false,
+
-  override def satisfies(required: Distribution): Boolean = true
+    /** indicates whether new null clustering keys may be generated by THIS operator. */
+    additionalNullClusterKeyGenerated: Boolean = false) {
 
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case SinglePartition => true
-    case _ => false
+  def withNumPartitions(num: Int): Partitioning = {
+    new Partitioning(
+      numPartitions = Some(num),
+      clusterKeys,
+      sortKeys,
+      globalOrdered,
+      additionalNullClusterKeyGenerated)
   }
 
-  override def keyExpressions: Seq[Expression] = Nil
-}
+  def withClusterKeys(clusterKeys: Seq[Expression]): Partitioning = {
+    new Partitioning(
+      numPartitions,
+      clusterKeys = clusterKeys,
+      sortKeys,
+      globalOrdered,
+      additionalNullClusterKeyGenerated)
+  }
 
-case object BroadcastPartitioning extends Partitioning {
-  val numPartitions = 1
+  def withSortKeys(sortKeys: Seq[SortOrder], globalOrdering: Boolean = false): Partitioning = {
+    if (globalOrdering) {
+      new Partitioning(
+        numPartitions,
+        clusterKeys = sortKeys.map(_.child),
+        sortKeys = sortKeys,
+        globalOrdered = true,
+        additionalNullClusterKeyGenerated)
+    } else {
+      new Partitioning(
+        numPartitions,
+        clusterKeys,
+        sortKeys = sortKeys,
+        globalOrdered = false,
+        additionalNullClusterKeyGenerated)
+    }
+  }
 
-  override def satisfies(required: Distribution): Boolean = true
+  def withGlobalOrdered(globalOrdered: Boolean): Partitioning = {
+    new Partitioning(
+      numPartitions,
+      clusterKeys,
+      sortKeys,
+      globalOrdered = globalOrdered,
+      additionalNullClusterKeyGenerated)
+  }
 
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case SinglePartition => true
-    case _ => false
+  def withAdditionalNullClusterKeyGenerated(nullClusterKeyGenerated: Boolean): Partitioning = {
+    new Partitioning(
+      numPartitions,
+      clusterKeys,
+      sortKeys,
+      globalOrdered,
+      additionalNullClusterKeyGenerated = nullClusterKeyGenerated)
   }
 
-  override def keyExpressions: Seq[Expression] = Nil
+  /**
+   * Compute the gap between the required data distribution and the existing data distribution.
+   *
+   * @param required the required data distribution
+   * @return the gap that needs to be applied to the existing data.
+   */
+  def gap(required: Distribution): Gap = required match {
+    case UnspecifiedDistribution => NoGap
+    case OrderedDistribution(ordering) if ordering == this.sortKeys && this.globalOrdered => NoGap
+    case OrderedDistribution(ordering) => GlobalOrdering(ordering)
+    case ClusteredDistribution(clustering, nullKeysSensitive, sortKeys) =>
+      if (this.globalOrdered) {
+        // The child is globally ordered (partitioned by range), so a ClusteredDistribution
+        // always requires repartitioning.
+        if (sortKeys.nonEmpty) { // required sorting
+          RepartitionKeyAndSort(clustering, sortKeys)
+        } else {
+          RepartitionKey(clustering)
+        }
+      } else {
+        // The child is not globally ordered; it is probably a clustered partitioning or
+        // an unspecified partitioning.
+        if (this.clusterKeys == clustering && clustering.nonEmpty) { // same distribution
+          if (nullKeysSensitive) {
+            // The parent requires that no NEW null cluster keys were generated by the child.
+            // E.g. in a GROUP BY clause, even when the clustering keys are the same, if new
+            // null clustering keys were generated in the child, all of the null clustering
+            // keys need to be put into a single partition.
+            if (this.additionalNullClusterKeyGenerated == false) {
+              // No new null clustering keys were generated by the child
+              if (sortKeys.isEmpty || sortKeys == this.sortKeys) {
+                // No sorting required, or the sorting keys match the current partitioning
+                NoGap
+              } else {
+                // Sort the data within each partition
+                SortKeyWithinPartition(sortKeys)
+              }
+            } else {
+              // The child may have generated null values for the cluster keys,
+              // so we need to repartition the data.
+              if (sortKeys.nonEmpty) { // required sorting
+                RepartitionKeyAndSort(clustering, sortKeys)
+              } else {
+                RepartitionKey(clustering)
+              }
+            }
+          } else {
+            // We do not care whether null cluster keys were generated by the child.
+            // E.g. in an equi-join it does not matter whether null keys end up in the same
+            // partition, because null keys are never considered equal to each other.
+            if (sortKeys.isEmpty || sortKeys == this.sortKeys) {
+              NoGap
+            } else {
+              SortKeyWithinPartition(sortKeys)
+            }
+          }
+        } else { // not the same distribution
+          if (sortKeys.nonEmpty) { // required sorting
+            RepartitionKeyAndSort(clustering, sortKeys)
+          } else {
+            RepartitionKey(clustering)
+          }
+        }
+      }
+  }
 }
 
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition.
- */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression
-  with Partitioning {
+// scalastyle:off
+/******************************************************************/
+/*            Helper utilities for the data partitioning          */
+/******************************************************************/
+// scalastyle:on
 
-  override def children: Seq[Expression] = expressions
-  override def nullable: Boolean = false
-  override def dataType: DataType = IntegerType
+object UnknownPartitioning extends Partitioning
 
-  private[this] lazy val clusteringSet = expressions.toSet
+object HashPartition {
+  type ReturnType = Option[(Seq[Expression], Int)] // (ClusteringKey, NumberOfPartition)
 
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
-      clusteringSet.subsetOf(requiredClustering.toSet)
-    case _ => false
+  def apply(clustering: Seq[Expression]): Partitioning = {
+    UnknownPartitioning.withClusterKeys(clustering)
   }
 
-  override def compatibleWith(other: Partitioning): Boolean = other match {
-    case BroadcastPartitioning => true
-    case h: HashPartitioning if h == this => true
-    case _ => false
+  def unapply(part: Partitioning): ReturnType = {
+    if (part.globalOrdered == false &&
+      part.clusterKeys.nonEmpty &&
+      part.sortKeys.isEmpty) {
+      Some(part.clusterKeys, part.numPartitions.get)
+    } else {
+      None
+    }
   }
+}
 
-  override def keyExpressions: Seq[Expression] = expressions
-
-  override def eval(input: InternalRow = null): Any =
-    throw new TreeNodeException(this, s"No function to evaluate expression.
type: ${this.nodeName}") +object RangePartition { + type ReturnType = Option[(Seq[SortOrder], Int)] // (Seq[SortOrder], NumberOfPartition) + def apply(ordering: Seq[SortOrder]): Partitioning = { + UnknownPartitioning.withSortKeys(ordering).withGlobalOrdered(true) + } + def unapply(part: Partitioning): ReturnType = { + if (part.globalOrdered && part.sortKeys.nonEmpty) { + Some(part.sortKeys, part.numPartitions.get) + } else { + None + } + } } -/** - * Represents a partitioning where rows are split across partitions based on some total ordering of - * the expressions specified in `ordering`. When data is partitioned in this manner the following - * two conditions are guaranteed to hold: - * - All row where the expressions in `ordering` evaluate to the same values will be in the same - * partition. - * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows - * that are in between `min` and `max` in this `ordering` will reside in this partition. - * - * This class extends expression primarily so that transformations over expression will descend - * into its child. - */ -case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) - extends Expression - with Partitioning { - - override def children: Seq[SortOrder] = ordering - override def nullable: Boolean = false - override def dataType: DataType = IntegerType - - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) - case _ => false +object HashPartitionWithSort { + // (Clustering Keys, Seq[SortOrder], NumberOfPartition) + type ReturnType = Option[(Seq[Expression], Seq[SortOrder], Int)] + + def apply(clustering: Seq[Expression], sortKeys: Seq[SortOrder]): Partitioning = { + UnknownPartitioning.withClusterKeys(clustering).withSortKeys(sortKeys) } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true - case _ => false + def unapply(part: Partitioning): ReturnType = { + if (part.globalOrdered == false && + part.clusterKeys.nonEmpty && + part.sortKeys.nonEmpty) { + Some(part.clusterKeys, part.sortKeys, part.numPartitions.get) + } else { + None + } } +} - override def keyExpressions: Seq[Expression] = ordering.map(_.child) +object SinglePartition { + def apply(): Partitioning = { + UnknownPartitioning.withClusterKeys(Nil).withNumPartitions(1) + } - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") + def unapply(part: Partitioning): Option[Int] = + if (part.numPartitions.get == 1 || part.clusterKeys.isEmpty) { + Some(1) + } else { + None + } } + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc2c..83507b05ced1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ @@ -25,11 +26,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class DistributionSuite extends SparkFunSuite { - protected def checkSatisfied( + protected def checkGap( inputPartitioning: Partitioning, requiredDistribution: Distribution, - satisfied: Boolean) { - if (inputPartitioning.satisfies(requiredDistribution) != satisfied) { + expectedGap: Gap) { + val gap = inputPartitioning.gap(requiredDistribution) + if (expectedGap != gap) { fail( s""" |== Input Partitioning == @@ -37,138 +39,172 @@ class DistributionSuite extends SparkFunSuite { |== Required Distribution == |$requiredDistribution |== Does input partitioning satisfy required distribution? == - |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} + |Expected $expectedGap got $gap """.stripMargin) } } - test("HashPartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + test("UnspecifiedDistribution is the required distribution") { + checkGap( + HashPartition(Seq('a, 'b, 'c)), UnspecifiedDistribution, - true) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), - true) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), - true) - - checkSatisfied( - SinglePartition, - ClusteredDistribution(Seq('a, 'b, 'c)), - true) - - checkSatisfied( - SinglePartition, - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - true) - - // Cases which need an exchange between two data properties. - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('b, 'c)), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('d, 'e)), - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, - false) - - checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) - - checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) - - // TODO: We should check functional dependencies - /* - checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ - } + NoGap) + + checkGap( + SinglePartition(), + UnspecifiedDistribution, + NoGap) - test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. 
- checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + checkGap( + RangePartition(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), UnspecifiedDistribution, - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.asc)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'b, 'a)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('b, 'c, 'a, 'd)), - true) - - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), - false) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), - false) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), - false) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), - false) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, - false) + NoGap) + + checkGap( + HashPartition(Seq('a, 'b, 'c)), + UnspecifiedDistribution, + NoGap) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false), + UnspecifiedDistribution, + NoGap) + } + + test("ClusteredDistribution is the required distribution") { + checkGap( + UnknownPartitioning, + ClusteredDistribution(Seq('a, 'b, 'c), true), + RepartitionKey(Seq('a, 'b, 'c))) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(true), + ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true), + RepartitionKey(Seq('a, 'b, 'c))) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = true), + NoGap) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('a, 'b, 'c), nullKeysSensitive = false), + NoGap) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('a, 'b, 'c), + nullKeysSensitive = true, + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + SortKeyWithinPartition(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + HashPartition(Seq('a, 'b, 'c)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('a, 'b, 'c), + nullKeysSensitive = false, + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + SortKeyWithinPartition(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + HashPartition(Seq('a, 'b)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('a, 'b), + 
nullKeysSensitive = false, + Seq(SortOrder('b, Ascending), SortOrder('a, Ascending))), + SortKeyWithinPartition(Seq(SortOrder('b, Ascending), SortOrder('a, Ascending)))) + + checkGap( + HashPartition(Seq('a, 'b)).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('b, 'a), + nullKeysSensitive = false, + Seq(SortOrder('a, Ascending), SortOrder('b, Ascending))), + RepartitionKeyAndSort(Seq('b, 'a), Seq(SortOrder('a, Ascending), SortOrder('b, Ascending)))) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = false), + NoGap) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + ClusteredDistribution(Seq('b, 'c), + nullKeysSensitive = false, + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + NoGap) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + ClusteredDistribution(Seq('c, 'd), + nullKeysSensitive = false, + Seq(SortOrder('e, Ascending), SortOrder('f, Ascending))), + RepartitionKeyAndSort(Seq('c, 'd), Seq(SortOrder('e, Ascending), SortOrder('f, Ascending)))) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))) + .withAdditionalNullClusterKeyGenerated(true), + ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = true), + RepartitionKey(Seq('b, 'c))) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))) + .withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('b, 'c), nullKeysSensitive = true), + NoGap) + + checkGap( + HashPartitionWithSort(Seq('c, 'b), + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))) + .withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('d, 'e), nullKeysSensitive = true), + RepartitionKey(Seq('d, 'e))) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), + Seq(SortOrder('b, Ascending), + SortOrder('c, Ascending))).withAdditionalNullClusterKeyGenerated(false), + ClusteredDistribution(Seq('b, 'c), + nullKeysSensitive = true, + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + NoGap) + } + + test("RangePartitioning is the required distribution") { + checkGap( + UnknownPartitioning, + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + SinglePartition(), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + HashPartition(Seq('b, 'c)).withAdditionalNullClusterKeyGenerated(false), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + HashPartition(Seq('b, 'c)), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + HashPartitionWithSort(Seq('b, 'c), + Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))) + .withAdditionalNullClusterKeyGenerated(false), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) + + checkGap( + RangePartition(Seq(SortOrder('b, Ascending), SortOrder('c, 
Ascending))), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + NoGap) + + checkGap( + RangePartition(Seq(SortOrder('c, Ascending), SortOrder('b, Ascending))), + OrderedDistribution(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending))), + GlobalOrdering(Seq(SortOrder('b, Ascending), SortOrder('c, Ascending)))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 6e8a5ef18ab6..3bc126e45eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -44,15 +44,17 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { + override def outputPartitioning: Partitioning = if (partial) { + child.outputPartitioning + } else { + HashPartition(groupingExpressions) + } + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } + ClusteredDistribution(groupingExpressions, true) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index edc64a03335d..4bc6663ad4b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -41,14 +41,11 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn @DeveloperApi case class Exchange( newPartitioning: Partitioning, - newOrdering: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning - override def outputOrdering: Seq[SortOrder] = newOrdering - override def output: Seq[Attribute] = child.output /** @@ -117,18 +114,28 @@ case class Exchange( } } - private val keyOrdering = { - if (newOrdering.nonEmpty) { - val key = newPartitioning.keyExpressions - val boundOrdering = newOrdering.map { o => - val ordinal = key.indexOf(o.child) - if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") - o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + private def createShuffleRDD(expressions: Seq[Expression], numPartitions: Int) = { + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + val serializer = getSerializer(keySchema, valueSchema, numPartitions) + val part = new HashPartitioner(numPartitions) + + val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() + iter.map(r => (hashExpressions(r).copy(), r.copy())) } - new RowOrdering(boundOrdering) } else { - null // Ordering will not be used + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[InternalRow, InternalRow]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } } + val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) + shuffled.setSerializer(serializer) + + shuffled.map(_._2) } @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf @@ -162,29 +169,10 @@ case class Exchange( protected 
override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { newPartitioning match { - case HashPartitioning(expressions, numPartitions) => - val keySchema = expressions.map(_.dataType).toArray - val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, valueSchema, numPartitions) - val part = new HashPartitioner(numPartitions) - - val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - iter.map(r => (hashExpressions(r).copy(), r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[InternalRow, InternalRow]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) - } - } - val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(serializer) - shuffled.map(_._2) + case HashPartition(expressions, numPartitions) => + createShuffleRDD(expressions, numPartitions) - case RangePartitioning(sortingExpressions, numPartitions) => + case RangePartition(sortingExpressions, numPartitions) => val keySchema = child.output.map(_.dataType).toArray val serializer = getSerializer(keySchema, null, numPartitions) @@ -214,7 +202,7 @@ case class Exchange( shuffled.setSerializer(serializer) shuffled.map(_._1) - case SinglePartition => + case SinglePartition(_) => val valueSchema = child.output.map(_.dataType).toArray val serializer = getSerializer(null, valueSchema, numPartitions = 1) val partitioner = new HashPartitioner(1) @@ -252,99 +240,65 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // True iff every child's outputPartitioning satisfies the corresponding - // required data distribution. - def meetsRequirements: Boolean = - operator.requiredChildDistribution.zip(operator.children).forall { - case (required, child) => - val valid = child.outputPartitioning.satisfies(required) - logDebug( - s"${if (valid) "Valid" else "Invalid"} distribution," + - s"required: $required current: ${child.outputPartitioning}") - valid - } - - // True iff any of the children are incorrectly sorted. - def needsAnySort: Boolean = - operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child.outputOrdering - } - - // True iff outputPartitionings of children are compatible with each other. - // It is possible that every child satisfies its required data distribution - // but two children have incompatible outputPartitionings. For example, - // A dataset is range partitioned by "a.asc" (RangePartitioning) and another - // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two - // datasets are both clustered by "a", but these two outputPartitionings are not - // compatible. - // TODO: ASSUMES TRANSITIVITY? 
- def compatible: Boolean = - !operator.children - .map(_.outputPartitioning) - .sliding(2) - .map { - case Seq(a) => true - case Seq(a, b) => a.compatibleWith(b) - }.exists(!_) - // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( - partitioning: Partitioning, - rowOrdering: Seq[SortOrder], + gap: Gap, child: SparkPlan): SparkPlan = { - val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering - val needsShuffle = child.outputPartitioning != partitioning - - val withShuffle = if (needsShuffle) { - Exchange(partitioning, Nil, child) - } else { - child - } - - val withSort = if (needSort) { + // add Sort Operator, which only for partition internal sorting, not global sorting + def addSortOperator(sortKeys: Seq[SortOrder], sparkplan: SparkPlan) + : SparkPlan = { if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) + ExternalSort(sortKeys, global = false, sparkplan) } else { - Sort(rowOrdering, global = false, withShuffle) + Sort(sortKeys, global = false, sparkplan) } - } else { - withShuffle } - withSort - } + def bindSortKeys(sortKeys: Seq[SortOrder], attributes: Seq[Attribute]) + : Seq[SortOrder] = { + val boundOrdering = sortKeys.map { o => + o.copy(child = BindReferences.bindReference(o.child, attributes)) + } - if (meetsRequirements && compatible && !needsAnySort) { - operator - } else { - // At least one child does not satisfies its required data distribution or - // at least one child's outputPartitioning is not compatible with another child's - // outputPartitioning. In this case, we need to add Exchange operators. - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + boundOrdering + } - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + gap match { + case NoGap => child + case SortKeyWithinPartition(sortKeys) => addSortOperator(sortKeys, child) + case GlobalOrdering(ordering) => + Exchange(RangePartition(ordering).withNumPartitions(numPartitions), child) + case RepartitionKey(clustering) => + val partitions = if (clustering.isEmpty) 1 else numPartitions + Exchange(HashPartition(clustering).withNumPartitions(partitions), child) + case RepartitionKeyAndSort(clustering, sortKeys) => + // TODO ideally, we probably will be benefit from the sort-based shuffle + // when the clustering keys is identical with the sortKeys, because the data + // is partially sorted during the shuffling. + // There are 2 concerns we need to consider if we want to implement this: + // 1) Detect if it's the sort-based shuffle. + // 2) SparkSqlSerializer2 will reuse the row (MutableRow) in deserialization + // that's will cause problem in sorting by ShuffledRDD. 
+ val num = if (clustering.isEmpty) 1 else numPartitions + val exchanged = Exchange(HashPartition(clustering).withNumPartitions(num), child) + addSortOperator( + bindSortKeys(sortKeys, exchanged.output), + exchanged) + } + } - case (UnspecifiedDistribution, Seq(), child) => - child - case (UnspecifiedDistribution, rowOrdering, child) => - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, child) - } else { - Sort(rowOrdering, global = false, child) - } + val gaps = operator.requiredChildDistribution.zip(operator.children).map { + case (requiredDistribution, child) => child.outputPartitioning.gap(requiredDistribution) + } - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") + if (gaps.exists(_ != NoGap)) { + val fixedChildren = gaps.zip(operator.children).map { case (gap, child) => + addOperatorsIfNecessary(gap, child) } operator.withNewChildren(fixedChildren) + } else { + operator } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 42a0c1be4f69..ad4c15c09780 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -40,7 +40,7 @@ case class Expand( // The GroupExpressions can output data with arbitrary partitioning, so set it // as UNKNOWN partitioning - override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputPartitioning: Partitioning = UnknownPartitioning protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 44930f82b53a..53504f0d8c7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -53,15 +53,17 @@ case class GeneratedAggregate( child: SparkPlan) extends UnaryNode { + override def outputPartitioning: Partitioning = if (partial) { + child.outputPartitioning + } else { + HashPartition(groupingExpressions) + } + override def requiredChildDistribution: Seq[Distribution] = if (partial) { UnspecifiedDistribution :: Nil } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } + ClusteredDistribution(groupingExpressions, true) :: Nil } override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 7739a9f949c7..3c8425451482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,18 +67,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ - def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + def outputPartitioning: Partitioning = UnknownPartitioning /** Specifies any partition requirements on the input data for this operator. 
*/ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) - /** Specifies how data is ordered in each partition. */ - def outputOrdering: Seq[SortOrder] = Nil - - /** Specifies sort order for each partition requirements on the input data for this operator. */ - def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d81758..191d05e57c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => execution.Exchange( - HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil + HashPartition(expressions).withNumPartitions(numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index fd6f1d7ae125..5c01a1bb35e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -21,7 +21,7 @@ import java.util import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.CompactBuffer /** @@ -40,33 +40,21 @@ case class Window( override def output: Seq[Attribute] = (projectList ++ windowExpression).map(_.toAttribute) - override def requiredChildDistribution: Seq[Distribution] = - if (windowSpec.partitionSpec.isEmpty) { - // This operator will be very expensive. - AllTuples :: Nil - } else { - ClusteredDistribution(windowSpec.partitionSpec) :: Nil - } - - // Since window functions are adding columns to the input rows, the child's outputPartitioning - // is preserved. - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + override def requiredChildDistribution: Seq[Distribution] = { // The required child ordering has two parts. // The first part is the expressions in the partition specification. // We add these expressions to the required ordering to make sure input rows are grouped // based on the partition specification. So, we only need to process a single partition // at a time. - // The second part is the expressions specified in the ORDER BY cluase. + // The second part is the expressions specified in the ORDER BY clause. // Basically, we first use sort to group rows based on partition specifications and then sort // Rows in a group based on the order specification. 
- (windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) :: Nil + + val sortKeys = (windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + ClusteredDistribution(windowSpec.partitionSpec, true, sortKeys) :: Nil } - // Since window functions basically add columns to input rows, this operator - // will not change the ordering of input rows. - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputPartitioning: Partitioning = HashPartition(windowSpec.partitionSpec) case class ComputedWindow( unbound: WindowExpression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 647c4ab5cb65..6fdf36f38475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -42,8 +42,6 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val reusableProjection = buildProjection() iter.map(reusableProjection) } - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -59,8 +57,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -123,7 +119,7 @@ case class Limit(limit: Int, child: SparkPlan) private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition + override def outputPartitioning: Partitioning = SinglePartition() override def executeCollect(): Array[Row] = child.executeTake(limit) @@ -162,7 +158,8 @@ case class TakeOrderedAndProject( override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition + override def outputPartitioning: Partitioning = + SinglePartition().withSortKeys(sortOrder).withGlobalOrdered(true) private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) @@ -182,8 +179,6 @@ case class TakeOrderedAndProject( // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) - - override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -210,7 +205,11 @@ case class Sort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputPartitioning: Partitioning = if (global) { + RangePartition(sortOrder) + } else { + child.outputPartitioning.withSortKeys(sortOrder) + } } /** @@ -242,7 +241,11 @@ case class ExternalSort( override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputPartitioning: Partitioning = if (global) { + RangePartition(sortOrder) + } else { + child.outputPartitioning.withSortKeys(sortOrder) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e41538ec1fc1..eeaf5a29fa9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -47,13 +47,13 @@ case class HashOuterJoin( override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case FullOuter => left.outputPartitioning.withAdditionalNullClusterKeyGenerated(true) case x => throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil override def output: Seq[Attribute] = { joinType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 20d74270afb4..48b1b774bf2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -38,7 +38,7 @@ case class LeftSemiJoinHash( override val buildSide: BuildSide = BuildRight override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil override def output: Seq[Attribute] = left.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2..06d22bbc0c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -41,7 +41,7 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, false) :: ClusteredDistribution(rightKeys, false) :: Nil protected override def doExecute(): RDD[InternalRow] = { 
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 2abe65a71813..8645ec7e144d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -39,19 +39,17 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = { + left.outputPartitioning.withSortKeys(requiredOrders(leftKeys)) + } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, false, requiredOrders(leftKeys)) :: + ClusteredDistribution(rightKeys, false, requiredOrders(rightKeys)) :: Nil // this is to manually construct an ordering that can be used to compare keys from both sides private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3dd24130af81..cec3e50900b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -30,6 +30,70 @@ import org.apache.spark.sql.{Row, SQLConf, execution} class PlannerSuite extends SparkFunSuite { + test("multiway full outer join") { + val planned = testData + .join(testData2, testData("key") === testData2("a"), "outer") + .join(testData3, testData("key") === testData3("a"), "outer") + .queryExecution.executedPlan + val exchanges = planned.collect { case n: Exchange => n } + + // testData testData2 testData3 + // \(shuffle: key) /(shuffle: a) /(shuffle:a) + // \ / + // | + // result + assert(exchanges.size === 3) + } + + test("full outer join followed by aggregation") { + val planned = testData + .join(testData2, testData("key") === testData2("a"), "outer") // join key testData('key) + .groupBy(testData("key")).agg(testData("key"), count("a")) // group by key testData('key) + .queryExecution.executedPlan + val exchanges = planned.collect { case n: Exchange => n } + + // testData testData2 + // \(shuffle:key) /(shuffle:a) + // \ / + // | (shuffle:key) + // result + assert(exchanges.size === 3) + } + + test("left outer join followed by aggregation") { + val planned = testData + .join(testData2, testData("key") === testData2("a"), "left") // join key testData('key) + .groupBy(testData("key")).agg(testData("key"), count("a")) // group by key testData('key) + .queryExecution.executedPlan + val exchanges = planned.collect { case n: Exchange => n } + + // testData testData2 + // \(shuffle:key) /(shuffle:a) + // \ / + // | (key) <--- partial aggregation (no shuffle) + // | (key) <--- final aggregation (no shuffle) + // result + 
assert(exchanges.size === 2) + } + + test("full outer join followed by left outer join and aggregation") { + val planned = testData + .join(testData2, testData("key") === testData2("a"), "outer") // join key testData('key) + .join(testData3, testData("key") === testData3("a"), "left") + .groupBy(testData("key")).agg(testData("key"), count(lit(1))) // group by key testData('key) + .queryExecution.executedPlan + val exchanges = planned.collect { case n: Exchange => n } + + // testData testData2 testData3 + // \(shuffle: key) /(shuffle: a) /(shuffle:a) + // \ (null key generated) / + // | + // | <--- partial aggregation + // | (shuffle:key) as the null key generated for the left side + // result <--- final aggregation + assert(exchanges.size === 4) + } + test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head diff --git a/sql/hive/src/test/resources/golden/multi outer joins-0-aee6404ddbbd9e6e93bd13b4de5b548c b/sql/hive/src/test/resources/golden/multi outer joins-0-aee6404ddbbd9e6e93bd13b4de5b548c new file mode 100644 index 000000000000..269ae550a495 --- /dev/null +++ b/sql/hive/src/test/resources/golden/multi outer joins-0-aee6404ddbbd9e6e93bd13b4de5b548c @@ -0,0 +1,30 @@ +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_98 val_97 val_95 NULL +val_97 val_96 NULL val_92 +val_97 val_96 NULL val_92 +val_96 val_95 NULL NULL +val_96 val_95 NULL NULL +val_95 NULL val_92 val_90 +val_95 NULL val_92 val_90 +val_95 NULL val_92 val_90 +val_95 NULL val_92 val_90 +val_95 NULL val_92 val_90 +val_95 NULL val_92 val_90 +val_92 NULL NULL val_87 +val_90 NULL val_87 val_85 +val_90 NULL val_87 val_85 +val_90 NULL val_87 val_85 +val_9 val_8 NULL val_4 +val_87 val_86 val_84 val_82 +val_87 val_86 val_84 val_82 +val_86 val_85 val_83 NULL +val_86 val_85 val_83 NULL +val_85 val_84 val_82 val_80 +val_85 val_84 val_82 val_80 +val_84 val_83 NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4cdba03b2702..8b04a0ef8545 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -85,6 +85,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("multi outer joins", + """SELECT a.value as c1, b.value as c2, c.value as c3, d.value as c4 + |FROM src a + | full outer join src b on a.key=b.key + 1 + | full outer join src c on a.key=c.key+3 + | full outer join src d on a.key=d.key+5 + |order by c1 desc, c2 desc, c3 desc, c4 desc + |limit 30""".stripMargin) + createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int);