diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 39dda9a13dad..153cdb5c69a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, OuterReference, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -220,7 +220,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
       if (attrMap.isEmpty) {
         planWithNewChildren
       } else {
-        planWithNewChildren.rewriteAttrs(attrMap)
+        def rewriteAttrs[T <: Expression](
+            exprs: Seq[T],
+            attrMap: Map[Attribute, Attribute]): Seq[T] = {
+          exprs.map { expr =>
+            expr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) {
+              case a: AttributeReference => attrMap.getOrElse(a, a)
+            }.asInstanceOf[T]
+          }
+        }
+
+        planWithNewChildren match {
+          // TODO (SPARK-44754): we should handle all special cases here.
+          case c: CoGroup =>
+            // SPARK-43781: CoGroup is a special case, `rewriteAttrs` will incorrectly update
+            // some fields that do not need to be updated. We need to update the output
+            // attributes of CoGroup manually.
+            val leftAttrMap = attrMap.filter(a => c.left.output.contains(a._2))
+            val rightAttrMap = attrMap.filter(a => c.right.output.contains(a._2))
+            val newLeftAttr = rewriteAttrs(c.leftAttr, leftAttrMap)
+            val newRightAttr = rewriteAttrs(c.rightAttr, rightAttrMap)
+            val newLeftGroup = rewriteAttrs(c.leftGroup, leftAttrMap)
+            val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
+            val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
+            val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
+            val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
+              .copy(inputAttributes = newLeftGroup)
+            val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
+              .copy(inputAttributes = newLeftAttr)
+            val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
+              .copy(inputAttributes = newRightAttr)
+            c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
+              rightDeserializer = newRightDes, leftGroup = newLeftGroup,
+              rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
+              leftOrder = newLeftOrder, rightOrder = newRightOrder)
+          case _ => planWithNewChildren.rewriteAttrs(attrMap)
+        }
       }
     } else {
       planWithNewSubquery.withNewChildren(newChildren.toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a021b049cf03..44b7c577bac7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest
     }
   }
 
+  test("SPARK-43781: cogroup two datasets derived from the same source") {
+    val inputType = StructType(Array(StructField("id", LongType, false),
+      StructField("type", StringType, false)))
+    val keyType = StructType(Array(StructField("id", LongType, false)))
+
+    val inputRows = new java.util.ArrayList[Row]()
+    inputRows.add(Row(1L, "foo"))
+    inputRows.add(Row(1L, "bar"))
+    inputRows.add(Row(2L, "foo"))
+    val input = spark.createDataFrame(inputRows, inputType)
+    val fooGroups = input.filter("type = 'foo'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+    val barGroups = input.filter("type = 'bar'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+
+    val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) =>
+      iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result.length == 3)
+
+    val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") {
+      case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result2.length == 3)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()