Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Fix v2 multi bucketed inner joins throw AssertionError
  • Loading branch information
ulysses-you committed Aug 9, 2024
commit f1f05fb110ee5eaa17d2e33e9d0151426c662942
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ case class EnsureRequirements(
private def createKeyGroupedShuffleSpec(
partitioning: Partitioning,
distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
val attributes = partitioning.expressions.flatMap(_.collectLeaves())
val clustering = distribution.clustering

Expand All @@ -636,11 +636,10 @@ case class EnsureRequirements(
}

partitioning match {
case p: KeyGroupedPartitioning => check(p)
case p: KeyGroupedPartitioning => tryCreate(p)
case PartitioningCollection(partitionings) =>
val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution))
assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined))
specs.head
specs.filter(_.isDefined).map(_.get).headOption
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,27 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
}

test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") {
val cols = Array(
Column.create("id", LongType),
Column.create("name", StringType))
val buckets = Array(bucket(8, "id"))

withTable("t1", "t2", "t3") {
Seq("t1", "t2", "t3").foreach { t =>
createTable(t, cols, buckets)
sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (1, 'aa'), (2, 'bb'), (3, 'cc')")
}
val df = sql(
"""
|SELECT * FROM testcat.ns.t1
|JOIN testcat.ns.t2 ON t1.id = t2.id
|JOIN testcat.ns.t3 ON t1.id = t3.id
|""".stripMargin)
assert(collectShuffles(df.queryExecution.executedPlan).isEmpty)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also check the result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a checkAnswer call to verify the query result.

}
}

test("partitioned join: join with two partition keys and matching & sorted partitions") {
val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
createTable(items, itemsColumns, items_partitions)
Expand Down