@@ -68,50 +68,42 @@ case object AllTuples extends Distribution {
}
}

-/**
- * Represents data where tuples that share the same values for the `clustering`
- * [[Expression Expressions]] will be co-located. Based on the context, this
- * can mean such tuples are either co-located in the same partition or they will be contiguous
- * within a single partition.
- */
-case class ClusteredDistribution(
-    clustering: Seq[Expression],
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
+abstract class ClusteredDistributionBase(exprs: Seq[Expression]) extends Distribution {
  require(
-    clustering != Nil,
-    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+    exprs.nonEmpty,
+    s"The clustering expressions of a ${getClass.getSimpleName} should not be empty. " +
      "An AllTuples should be used to represent a distribution that only has " +
      "a single partition.")

  override def createPartitioning(numPartitions: Int): Partitioning = {
    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
-      s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+      s"This ${getClass.getSimpleName} requires ${requiredNumPartitions.get} partitions, but " +
      s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(clustering, numPartitions)
+    HashPartitioning(exprs, numPartitions)
  }
}

/**
- * Represents data where tuples have been clustered according to the hash of the given
- * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * Represents data where tuples that share the same values for the `clustering`
+ * [[Expression Expressions]] will be co-located. Based on the context, this
+ * can mean such tuples are either co-located in the same partition or they will be contiguous
+ * within a single partition.
+ */
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(clustering)
+
+/**
+ * Represents data where tuples have been clustered according to the hash of the given expressions.
+ * The hash function is defined as [[HashPartitioning.partitionIdExpression]], so only
 * [[HashPartitioning]] can satisfy this distribution.
 *
 * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
 * number of partitions, this distribution strictly requires which partition the tuple should be in.
 */
-case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
Review comment (Contributor): I do not see any new tests in DistributionSuite. Issues like this should have dedicated unit tests in DistributionSuite rather than relying on StreamingJoinSuite.
-  require(
-    expressions != Nil,
-    "The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
-      "An AllTuples should be used to represent a distribution that only has " +
-      "a single partition.")
-
-  override def requiredNumPartitions: Option[Int] = None
-
-  override def createPartitioning(numPartitions: Int): Partitioning = {
-    HashPartitioning(expressions, numPartitions)
-  }
-}
+case class HashClusteredDistribution(
+    hashExprs: Seq[Expression],
+    requiredNumPartitions: Option[Int] = None) extends ClusteredDistributionBase(hashExprs)
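Taken together, the hunk above folds the duplicated validation and createPartitioning logic into ClusteredDistributionBase, and lets HashClusteredDistribution carry a required partition count just like ClusteredDistribution. A minimal sketch of the resulting behavior, assuming a hypothetical attribute `key` (not part of the diff):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()

// Both subclasses now share the base class's require() and createPartitioning().
val dist = HashClusteredDistribution(Seq(key), requiredNumPartitions = Some(200))
dist.createPartitioning(200)   // HashPartitioning(Seq(key), 200)
// dist.createPartitioning(100) would trip the inherited assert: exactly 200 partitions required.
// HashClusteredDistribution(Nil) fails the shared require() at construction time.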

/**
* Represents data where tuples have been ordered according to the `ordering`
@@ -170,8 +162,10 @@ trait Partitioning {
  def satisfies(required: Distribution): Boolean = required match {
    case UnspecifiedDistribution => true
    case AllTuples => numPartitions == 1
-    case _ => false
+    case _ => required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
  }

+  protected def satisfies0(required: Distribution): Boolean = false
}
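The refactor makes satisfies a template method: the base trait handles the cases common to every partitioning (UnspecifiedDistribution, AllTuples, and the new requiredNumPartitions check), while subclasses override only satisfies0 with their specific rules, so the partition-count check can no longer be accidentally skipped. A sketch of a custom partitioning under this contract (hypothetical class, for illustration only):

// Hypothetical two-partition scheme; it overrides satisfies0, never satisfies,
// so the shared checks in Partitioning.satisfies still apply.
case class EvenOddPartitioning(expr: Expression) extends Partitioning {
  override val numPartitions: Int = 2

  override protected def satisfies0(required: Distribution): Boolean = required match {
    case ClusteredDistribution(clustering, _) =>
      clustering.exists(_.semanticEquals(expr))
    case _ => false
  }
}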

case class UnknownPartitioning(numPartitions: Int) extends Partitioning
@@ -186,9 +180,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
case object SinglePartition extends Partitioning {
  val numPartitions = 1

-  override def satisfies(required: Distribution): Boolean = required match {
+  override def satisfies0(required: Distribution): Boolean = required match {
Review comment (tdas, Contributor, Jun 20, 2018): Can we add docs to explain what satisfies0 is and how it differs from satisfies? Otherwise it's quite confusing: when does one override satisfies, and when does one override satisfies0?

Reply (Author, Contributor): Added in the base class.

    case _: BroadcastDistribution => false
-    case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1
    case _ => true
  }
}
@@ -205,18 +198,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

-  override def satisfies(required: Distribution): Boolean = {
-    super.satisfies(required) || {
-      required match {
-        case h: HashClusteredDistribution =>
-          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
-            case (l, r) => l.semanticEquals(r)
-          }
-        case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
-          expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
-        case _ => false
-      }
-    }
-  }
+  override def satisfies0(required: Distribution): Boolean = {
+    required match {
+      case h: HashClusteredDistribution =>
+        expressions.length == h.hashExprs.length && expressions.zip(h.hashExprs).forall {
+          case (l, r) => l.semanticEquals(r)
+        }
+      case c: ClusteredDistribution =>
+        expressions.forall(x => c.clustering.exists(_.semanticEquals(x)))
+      case _ => false
+    }
+  }
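The two cases above encode guarantees of different strength: ClusteredDistribution is satisfied when the hash expressions are a subset of the required clustering keys, whereas HashClusteredDistribution demands exactly the same expressions in the same order. A sketch of the intended semantics, assuming integer attributes a, b, c (illustrative, not from the diff):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", IntegerType)()
val p = HashPartitioning(Seq(a, b), numPartitions = 10)

p.satisfies(ClusteredDistribution(Seq(a, b, c)))       // true: {a, b} is a subset of the clustering keys
p.satisfies(HashClusteredDistribution(Seq(a, b, c)))   // false: exact, ordered match required
p.satisfies(HashClusteredDistribution(Seq(b, a)))      // false: order matters (zip + semanticEquals)
p.satisfies(HashClusteredDistribution(Seq(a, b)))      // true
p.satisfies(ClusteredDistribution(Seq(a, b), Some(5))) // false: the base trait rejects the partition count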

@@ -246,17 +236,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

-  override def satisfies(required: Distribution): Boolean = {
-    super.satisfies(required) || {
-      required match {
-        case OrderedDistribution(requiredOrdering) =>
-          val minSize = Seq(requiredOrdering.size, ordering.size).min
-          requiredOrdering.take(minSize) == ordering.take(minSize)
-        case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
-          ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-            (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
-        case _ => false
-      }
-    }
-  }
+  override def satisfies0(required: Distribution): Boolean = {
+    required match {
+      case OrderedDistribution(requiredOrdering) =>
+        val minSize = Seq(requiredOrdering.size, ordering.size).min
+        requiredOrdering.take(minSize) == ordering.take(minSize)
+      case ClusteredDistribution(requiredClustering, _) =>
+        ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
+      case _ => false
+    }
+  }
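Note the prefix rule in the OrderedDistribution case: only the common prefix of the two orderings is compared, so a range partitioning on (a.asc) is treated as satisfying a required ordering of (a.asc, b.asc). A sketch, with a and b as assumed attributes (illustrative, not from the diff):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

val r = RangePartitioning(Seq(SortOrder(a, Ascending)), numPartitions = 10)
r.satisfies(OrderedDistribution(Seq(SortOrder(a, Ascending), SortOrder(b, Ascending))))  // true: prefix (a.asc) matches
r.satisfies(ClusteredDistribution(Seq(a)))  // true: every ordering child appears in the required clustering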
@@ -30,20 +30,18 @@ class DataSourcePartitioning(

  override val numPartitions: Int = partitioning.numPartitions()

-  override def satisfies(required: physical.Distribution): Boolean = {
-    super.satisfies(required) || {
-      required match {
-        case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
-          val attrs = d.clustering.map(_.asInstanceOf[Attribute])
-          partitioning.satisfy(
-            new ClusteredDistribution(attrs.map { a =>
-              val name = colNames.get(a)
-              assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
-              name.get
-            }.toArray))
-
-        case _ => false
-      }
-    }
-  }
+  override def satisfies0(required: physical.Distribution): Boolean = {
+    required match {
+      case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
+        val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+        partitioning.satisfy(
+          new ClusteredDistribution(attrs.map { a =>
+            val name = colNames.get(a)
+            assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
+            name.get
+          }.toArray))
+
+      case _ => false
+    }
+  }

@@ -73,7 +73,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
      // these children may not be partitioned in the same way.
      // Please see the comment in withCoordinator for more details.
      val supportsDistribution = requiredChildDistributions.forall { dist =>
-        dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
+        dist.isInstanceOf[ClusteredDistributionBase]
      }
      children.length > 1 && supportsDistribution
    }
@@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec(
  val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)

  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
-      ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
+    HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+      HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
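This is the substantive fix: ClusteredDistribution only promises that rows sharing key values are co-located within each child, not that both children agree on which partition a given key maps to, so the two sides of a streaming join could end up partitioned differently. HashClusteredDistribution can only be satisfied by HashPartitioning, whose routing function is fixed, so both children and the join's state store partitions line up. For reference, the routing expression is defined in HashPartitioning roughly as follows (paraphrased from Spark's source, not part of this diff):

// Partition id a row is routed to: a Murmur3 hash of the expressions, mod numPartitions.
// Since only HashPartitioning satisfies HashClusteredDistribution, both join children
// are guaranteed to route their join keys through this same function.
def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))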

  override def output: Seq[Attribute] = joinType match {
    case _: InnerLike => left.output ++ right.output
@@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
      AddData(input3, 5, 10),
      CheckNewAnswer((5, 10, 5, 15, 5, 25)))
  }
+
+  test("streaming join should require HashClusteredDistribution from children") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b)
+    val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b)
+    val joined = df1.join(df2, Seq("a", "b")).select('a)
+
+    testStream(joined)(
+      AddData(input1, 1.to(1000): _*),
+      AddData(input2, 1.to(1000): _*),
+      CheckAnswer(1.to(1000): _*))
+  }
}
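Why the new test exercises the bug: before this change, df2's repartition('b) produced HashPartitioning(Seq(b), n), which satisfies ClusteredDistribution(Seq(a, b)) under the subset rule, so no exchange was inserted on that side, while df1 was shuffled into HashPartitioning(Seq(a, b), n). Matching rows could then land in different partitions and silently drop from the join, leaving CheckAnswer short. Requiring HashClusteredDistribution forces EnsureRequirements to shuffle both sides identically. A sketch of the mismatch (hypothetical numbers):

// With n = 5 partitions and a row where (a, b) = (3, 6):
//   df1 side: partitionId = murmur3Hash(3, 6) % 5   (shuffled on both join keys)
//   df2 side: partitionId = murmur3Hash(6) % 5      (already partitioned on 'b alone)
// The two ids generally differ, so matching rows never meet in the same task.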

