Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
StreamingSymmetricHashJoinExec should require HashClusteredPartitioning from children
  • Loading branch information
cloud-fan committed Jun 19, 2018
commit d102da370babba06cfa1a349a98bbe56dda3d056
Original file line number Diff line number Diff line change
Expand Up @@ -68,50 +68,42 @@ case object AllTuples extends Distribution {
}
}

/**
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located. Based on the context, this
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends Distribution {
abstract class ClusteredDistributionBase(exprs: Seq[Expression]) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
exprs.nonEmpty,
s"The clustering expressions of a ${getClass.getSimpleName} should not be empty. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

override def createPartitioning(numPartitions: Int): Partitioning = {
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"This ${getClass.getSimpleName} requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
HashPartitioning(clustering, numPartitions)
HashPartitioning(exprs, numPartitions)
}
}

/**
* Represents data where tuples have been clustered according to the hash of the given
* `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located. Based on the context, this
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(clustering)

/**
* Represents data where tuples have been clustered according to the hash of the given expressions.
* The hash function is defined as [[HashPartitioning.partitionIdExpression]], so only
* [[HashPartitioning]] can satisfy this distribution.
*
* This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
* number of partitions, this distribution strictly requires which partition the tuple should be in.
*/
case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see any new tests in DistributionSuite. I feel that issues like this should have specific unit tests in DistributionSuite and shouldn't have to rely on StreamingJoinSuite.

require(
expressions != Nil,
"The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

override def requiredNumPartitions: Option[Int] = None

override def createPartitioning(numPartitions: Int): Partitioning = {
HashPartitioning(expressions, numPartitions)
}
}
case class HashClusteredDistribution(
hashExprs: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(hashExprs)

/**
* Represents data where tuples have been ordered according to the `ordering`
Expand Down Expand Up @@ -207,15 +199,18 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

override def satisfies(required: Distribution): Boolean = {
super.satisfies(required) || {
required match {
case h: HashClusteredDistribution =>
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
case _ => false
val satisfyNumPartitions = required.requiredNumPartitions.isEmpty ||
required.requiredNumPartitions.get == numPartitions
satisfyNumPartitions && {
required match {
case h: HashClusteredDistribution =>
expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall {
case (l, r) => l.semanticEquals(r)
}
case c: ClusteredDistribution =>
expressions.forall(x => c.clustering.exists(_.semanticEquals(x)))
case _ => false
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// these children may not be partitioned in the same way.
// Please see the comment in withCoordinator for more details.
val supportsDistribution = requiredChildDistributions.forall { dist =>
dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
dist.isInstanceOf[ClusteredDistributionBase]
}
children.length > 1 && supportsDistribution
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec(
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil

override def output: Seq[Attribute] = joinType match {
case _: InnerLike => left.output ++ right.output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(input3, 5, 10),
CheckNewAnswer((5, 10, 5, 15, 5, 25)))
}

test("streaming join should require HashClusteredDistribution from children") {
  // Two independent in-memory sources feed the two sides of the join.
  val leftInput = MemoryStream[Int]
  val rightInput = MemoryStream[Int]

  // The same values are pushed to both sides and expected back from the join.
  val values = 1.to(1000)

  val leftDf = leftInput.toDF.select('value as 'a, 'value * 2 as 'b)
  // The right side is explicitly repartitioned on 'b only, although the join
  // below is keyed on both "a" and "b".
  val rightDf = rightInput.toDF
    .select('value as 'a, 'value * 2 as 'b)
    .repartition('b)

  val joinedOnBothKeys = leftDf.join(rightDf, Seq("a", "b")).select('a)

  // Every input value must appear exactly once in the join output.
  testStream(joinedOnBothKeys)(
    AddData(leftInput, values: _*),
    AddData(rightInput, values: _*),
    CheckAnswer(values: _*))
}
}


Expand Down