Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -678,22 +678,46 @@ object CoGroup {
right: LogicalPlan): LogicalPlan = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))

// SPARK-42132: The DeduplicateRelations rule would replace duplicate attributes
// in the right plan and rewrite rightGroup and rightAttr. But it would also rewrite
// leftGroup and leftAttr, which is wrong. Additionally, it does not rewrite rightDeserializer.
// Aliasing duplicate attributes in the right plan deduplicates them and stops
// DeduplicateRelations from doing harm.
val duplicateAttributes = AttributeMap(
right.output.filter(left.output.contains).map(a => a -> Alias(a, a.name)())
)

def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] =
attrs.map(attr => duplicateAttributes.getOrElse(attr, attr))

// rightOrder is resolved against right plan, so deduplication not needed
val (dedupRightGroup, dedupRightAttr, dedupRight) =
if (duplicateAttributes.nonEmpty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit weird to do dedup here. Can we update the DeduplicateRelations rule to handle CoGroup specially?

Copy link
Contributor Author

@EnricoMi EnricoMi Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really see how DeduplicateRelations can be modified to not rewrite all attributes of CoGroup.

In the DeduplicateRelations.apply method, renewDuplicatedRelations is called, which in turn calls rewriteAttrs(attrMap) on the CoGroup, and that rewrites all of its attributes.

If you are suggesting to add case cogroup @ CoGroup(...) => to

newPlan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND),
ruleId) {
case p: LogicalPlan if !p.childrenResolved => p
// To resolve duplicate expression IDs for Join.
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// Resolve duplicate output for LateralJoin.
case j @ LateralJoin(left, right, _, _) if right.resolved && !j.duplicateResolved =>
j.copy(right = right.withNewPlan(dedupRight(left, right.plan)))
// Resolve duplicate output for AsOfJoin.
case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
e.copy(right = dedupRight(left, right))
// Only after we finish by-name resolution for Union
case u: Union if !u.byName && !u.duplicateResolved =>
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
// feature in streaming.
val newChildren = u.children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) =>
head +: tail.map {
case child if head.outputSet.intersect(child.outputSet).isEmpty =>
child
case child =>
val projectList = child.output.map { attr =>
Alias(attr, attr.name)()
}
Project(projectList, child)
}
}
u.copy(children = newChildren)
case merge: MergeIntoTable if !merge.duplicateResolved =>
merge.copy(sourceTable = dedupRight(merge.targetTable, merge.sourceTable))
}
}

then this won't work because all attributes of CoGroup have been rewritten at this point.

I'd appreciate some pointers or a sketch of a solution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Hisoka-X is working on it: #41554

(
dedup(rightGroup).map(_.toAttribute),
dedup(rightAttr).map(_.toAttribute),
Project(dedup(right.output), right)
)
} else {
(rightGroup, rightAttr, right)
}

val cogrouped = CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
// The `leftGroup` and `rightGroup` are guaranteed to have the same schema, so it's safe to
// resolve the `keyDeserializer` based on either of them, here we pick the left one.
UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup),
UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr),
UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr),
UnresolvedDeserializer(encoderFor[R].deserializer, dedupRightAttr),
leftGroup,
rightGroup,
dedupRightGroup,
leftAttr,
rightAttr,
dedupRightAttr,
leftOrder,
rightOrder,
CatalystSerde.generateObjAttr[OUT],
left,
right)
dedupRight)
CatalystSerde.serialize[OUT](cogrouped)
}
}
Expand Down
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,20 @@ class DatasetSuite extends QueryTest
}
}

test("SPARK-42132: cogroup with same plan on both sides") {
  // Both sides of the cogroup are built from the very same Dataset, so the left and the
  // right plan expose the same attributes.
  val df = spark.range(3)

  val leftGrouped = df.groupBy("id").as[Long, Long]
  val rightGrouped = df.groupBy("id").as[Long, Long]

  // Pair every left row with the right row of the same group; the key itself is not needed.
  val cogrouped = leftGrouped.cogroup(rightGrouped) {
    case (_, left, right) => left.zip(right)
  }

  // `sort()` without any sort expression cannot be relied on to order the rows, so the
  // collected rows are ordered on the driver to keep the comparison below deterministic.
  val actual = cogrouped.collect().sorted
  assert(actual === Seq((0L, 0L), (1L, 1L), (2L, 2L)))
}

test("SPARK-34806: observation on datasets") {
val namedObservation = Observation("named")
val unnamedObservation = Observation()
Expand Down