Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, OuterReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
Expand Down Expand Up @@ -220,7 +220,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
def rewriteAttrsMatchWithSubPlan[T <: Expression](
attrs: Seq[T],
attrMap: Map[Attribute, Attribute]): Seq[T] = {
attrs.map(attr => {
attr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) {
case a: AttributeReference =>
attrMap.getOrElse(a, a)
}.asInstanceOf[T]
})
}

planWithNewChildren match {
case c: CoGroup =>
// SPARK-43781: CoGroup is a special case, `rewriteAttrs` will incorrectly update
// some fields that do not need to be updated. We need to update the output
// attributes of CoGroup manually.
val leftAttrMap = attrMap.filter(a => c.left.output.contains(a._2))
val rightAttrMap = attrMap.filter(a => c.right.output.contains(a._2))
val newLeftAttr = c.leftAttr.map(attr => attrMap.getOrElse(attr, attr))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the left-right attributes does not need to call rewriteAttrsMatchWithSubPlan?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just changed by review #41554 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the same apple to newRightAttr?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and to left/right group?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the same apple to newRightAttr?

In fact it will be more safer, but I can't produce any negetive case.

and to left/right group?

what's meaning? Sorry I don't get it. The left/right group already apply rewriteAttrsMatchWithSubPlan.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, if you mean the leftAttr should same to rightAttr and left/right group. I would say yes. Maybe we should add rewriteAttrsMatchWithSubPlan to newLeftAttr. After code refactoring, invoke rewriteAttrsMatchWithSubPlan not a big deal. So I add it back.

val newRightAttr = rewriteAttrsMatchWithSubPlan(c.rightAttr, rightAttrMap)
val newLeftGroup = rewriteAttrsMatchWithSubPlan(c.leftGroup, leftAttrMap)
val newRightGroup = rewriteAttrsMatchWithSubPlan(c.rightGroup, rightAttrMap)
val newLeftOrder = rewriteAttrsMatchWithSubPlan(c.leftOrder, leftAttrMap)
val newRightOrder = rewriteAttrsMatchWithSubPlan(c.rightOrder, rightAttrMap)
val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
.copy(inputAttributes = newLeftGroup)
val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
.copy(inputAttributes = newLeftAttr)
val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
.copy(inputAttributes = newRightAttr)
c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
rightDeserializer = newRightDes, leftGroup = newLeftGroup,
rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
leftOrder = newLeftOrder, rightOrder = newRightOrder)
case _ => planWithNewChildren.rewriteAttrs(attrMap)
}
}
} else {
planWithNewSubquery.withNewChildren(newChildren.toSeq)
Expand Down
26 changes: 26 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest
}
}

test("SPARK-43781: cogroup two datasets derived from the same source") {
val inputType = StructType(Array(StructField("id", LongType, false),
StructField("type", StringType, false)))
val keyType = StructType(Array(StructField("id", LongType, false)))

val inputRows = new java.util.ArrayList[Row]()
inputRows.add(Row(1L, "foo"))
inputRows.add(Row(1L, "bar"))
inputRows.add(Row(2L, "foo"))
val input = spark.createDataFrame(inputRows, inputType)
val fooGroups = input.filter("type = 'foo'").groupBy("id").as(ExpressionEncoder(keyType),
ExpressionEncoder(inputType))
val barGroups = input.filter("type = 'bar'").groupBy("id").as(ExpressionEncoder(keyType),
ExpressionEncoder(inputType))

val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) =>
iterator.toSeq ++ iterator1.toSeq
}(ExpressionEncoder(inputType)).collect()
assert(result.length == 3)

val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") {
case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq
}(ExpressionEncoder(inputType)).collect()
assert(result2.length == 3)
}

test("SPARK-34806: observation on datasets") {
val namedObservation = Observation("named")
val unnamedObservation = Observation()
Expand Down