diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 54c0b84ff523..506237c9ea91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -678,22 +678,46 @@ object CoGroup {
       right: LogicalPlan): LogicalPlan = {
     require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
 
+    // SPARK-42132: The DeduplicateRelations rule would replace duplicate attributes
+    // in the right plan and rewrite rightGroup and rightAttr. But it would also rewrite
+    // leftGroup and leftAttr, which is wrong. Additionally, it does not rewrite rightDeserializer.
+    // Aliasing duplicate attributes in the right plan deduplicates them and stops
+    // DeduplicateRelations from doing harm.
+    val duplicateAttributes = AttributeMap(
+      right.output.filter(left.output.contains).map(a => a -> Alias(a, a.name)())
+    )
+
+    def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] =
+      attrs.map(attr => duplicateAttributes.getOrElse(attr, attr))
+
+    // rightOrder is resolved against right plan, so deduplication not needed
+    val (dedupRightGroup, dedupRightAttr, dedupRight) =
+      if (duplicateAttributes.nonEmpty) {
+        (
+          dedup(rightGroup).map(_.toAttribute),
+          dedup(rightAttr).map(_.toAttribute),
+          Project(dedup(right.output), right)
+        )
+      } else {
+        (rightGroup, rightAttr, right)
+      }
+
     val cogrouped = CoGroup(
       func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
       // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to
       // resolve the `keyDeserializer` based on either of them, here we pick the left one.
       UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup),
       UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr),
-      UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr),
+      UnresolvedDeserializer(encoderFor[R].deserializer, dedupRightAttr),
       leftGroup,
-      rightGroup,
+      dedupRightGroup,
       leftAttr,
-      rightAttr,
+      dedupRightAttr,
       leftOrder,
       rightOrder,
       CatalystSerde.generateObjAttr[OUT],
       left,
-      right)
+      dedupRight)
     CatalystSerde.serialize[OUT](cogrouped)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0766dd2e7726..b3db996a9c3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -916,6 +916,20 @@ class DatasetSuite extends QueryTest
     }
   }
 
+  test("SPARK-42132: cogroup with same plan on both sides") {
+    val df = spark.range(3)
+
+    val left_grouped_df = df.groupBy("id").as[Long, Long]
+    val right_grouped_df = df.groupBy("id").as[Long, Long]
+
+    val cogroup_df = left_grouped_df.cogroup(right_grouped_df) {
+      case (key, left, right) => left.zip(right)
+    }
+
+    val actual = cogroup_df.sort().collect()
+    assert(actual === Seq((0, 0), (1, 1), (2, 2)))
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()