sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, OuterReference, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -220,7 +220,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         if (attrMap.isEmpty) {
           planWithNewChildren
         } else {
-          planWithNewChildren.rewriteAttrs(attrMap)
+          def rewriteAttrs[T <: Expression](
+              exprs: Seq[T],
+              attrMap: Map[Attribute, Attribute]): Seq[T] = {
+            exprs.map { expr =>
+              expr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) {
+                case a: AttributeReference => attrMap.getOrElse(a, a)
+              }.asInstanceOf[T]
+            }
+          }
+
+          planWithNewChildren match {
+            // TODO (SPARK-44754): we should handle all special cases here.
+            case c: CoGroup =>
+              // SPARK-43781: CoGroup is a special case: `rewriteAttrs` would
+              // incorrectly update some fields that must stay unchanged, so we
+              // update the output attributes of CoGroup manually.
+              val leftAttrMap = attrMap.filter(a => c.left.output.contains(a._2))
+              val rightAttrMap = attrMap.filter(a => c.right.output.contains(a._2))
+              val newLeftAttr = rewriteAttrs(c.leftAttr, leftAttrMap)
+              val newRightAttr = rewriteAttrs(c.rightAttr, rightAttrMap)
+              val newLeftGroup = rewriteAttrs(c.leftGroup, leftAttrMap)
+              val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
+              val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
+              val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
+              val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
+                .copy(inputAttributes = newLeftGroup)
+              val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
+                .copy(inputAttributes = newLeftAttr)
+              val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
+                .copy(inputAttributes = newRightAttr)
+              c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
+                rightDeserializer = newRightDes, leftGroup = newLeftGroup,
+                rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
+                leftOrder = newLeftOrder, rightOrder = newRightOrder)
+            case _ => planWithNewChildren.rewriteAttrs(attrMap)
+          }
         }
       } else {
         planWithNewSubquery.withNewChildren(newChildren.toSeq)
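The `rewriteAttrs` helper above substitutes attribute references that went stale when DeduplicateRelations re-aliased one occurrence of a shared relation, mapping each old reference to its renewed counterpart and leaving every other node untouched. Here is a minimal sketch of that substitution in isolation, with hypothetical attributes and plain `transform` standing in for the pruned `transformWithPruning` traversal used in the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.LongType

// Hypothetical attributes for illustration: `stale` is the reference still
// embedded in the plan; `renewed` is the same column after deduplication
// assigned a fresh expression ID.
val stale = AttributeReference("id", LongType)()
val renewed = stale.newInstance()
val attrMap = Map[Attribute, Attribute](stale -> renewed)

// Rewrite bottom-up: stale references are swapped for their renewed
// counterparts; unrelated nodes pass through unchanged.
val rewritten = EqualTo(stale, Literal(1L)).transform {
  case a: AttributeReference => attrMap.getOrElse(a, a)
}
assert(rewritten == EqualTo(renewed, Literal(1L)))
```

For CoGroup this blanket rewrite is too eager, which is why the patch filters `attrMap` per side and rebuilds the three deserializers explicitly instead of calling `rewriteAttrs` on the whole node.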
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (26 additions, 0 deletions)
@@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest
     }
   }
 
+  test("SPARK-43781: cogroup two datasets derived from the same source") {
+    val inputType = StructType(Array(StructField("id", LongType, false),
+      StructField("type", StringType, false)))
+    val keyType = StructType(Array(StructField("id", LongType, false)))
+
+    val inputRows = new java.util.ArrayList[Row]()
+    inputRows.add(Row(1L, "foo"))
+    inputRows.add(Row(1L, "bar"))
+    inputRows.add(Row(2L, "foo"))
+    val input = spark.createDataFrame(inputRows, inputType)
+    val fooGroups = input.filter("type = 'foo'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+    val barGroups = input.filter("type = 'bar'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+
+    val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) =>
+      iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result.length == 3)
+
+    val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") {
+      case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result2.length == 3)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()
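The regression test goes through `createDataFrame` plus explicit `ExpressionEncoder`s; the same self-source cogroup shape can also be written against the typed Dataset API. A minimal sketch of that variant, assuming a live `SparkSession` in scope as `spark` (illustrative only, not part of the patch):

```scala
import spark.implicits._

// Both groupings derive from the same Dataset, so the analyzer must
// deduplicate the shared relation and renew the attributes on one side.
val input = Seq((1L, "foo"), (1L, "bar"), (2L, "foo")).toDS()
val fooGroups = input.filter(_._2 == "foo").groupByKey(_._1)
val barGroups = input.filter(_._2 == "bar").groupByKey(_._1)

// Without the fix, CoGroup could keep stale attribute references in its
// deserializers after deduplication and queries of this shape failed;
// with it, all three rows come back.
val merged = fooGroups.cogroup(barGroups) { (_, foos, bars) => foos ++ bars }
assert(merged.collect().length == 3)
```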