From 501a80f310e3d855d015a0fb8b6e3e2e408f5f78 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 18 Jan 2023 09:08:43 +0100 Subject: [PATCH 1/9] Test and fix cogrouping same dataframe deduplication issue --- .../sql/catalyst/plans/logical/object.scala | 62 +++++++++++++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 50 +++++++++++++++ .../exchange/EnsureRequirementsSuite.scala | 40 +++++++++++- 3 files changed, 145 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 54c0b84ff523..121b4f8b514b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -677,6 +677,28 @@ object CoGroup { left: LogicalPlan, right: LogicalPlan): LogicalPlan = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) + val duplicateAttributes = right.output.filter(left.output.contains) + .map(a => a -> Alias(a, a.name)()).toMap + + def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] = { + if (duplicateAttributes.nonEmpty) { + attrs.map(attr => duplicateAttributes.getOrElse(attr, attr)) + } else { + attrs + } + } + + val (dedupRightGroup, dedupRightAttr, dedupRightOrder, dedupRight) = + if (duplicateAttributes.nonEmpty) { + ( + dedup(rightGroup).map(_.toAttribute), + dedup(rightAttr).map(_.toAttribute), + rightOrder, + Project(dedup(right.output), right) + ) + } else { + (rightGroup, rightAttr, rightOrder, right) + } val cogrouped = CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], @@ -684,16 +706,16 @@ object CoGroup { // resolve the `keyDeserializer` based on either of them, here we pick the left one. 
UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup), UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr), - UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr), + UnresolvedDeserializer(encoderFor[R].deserializer, dedupRightAttr), leftGroup, - rightGroup, + dedupRightGroup, leftAttr, - rightAttr, + dedupRightAttr, leftOrder, - rightOrder, + dedupRightOrder, CatalystSerde.generateObjAttr[OUT], left, - right) + dedupRight) CatalystSerde.serialize[OUT](cogrouped) } } @@ -716,6 +738,36 @@ case class CoGroup( outputObjAttr: Attribute, left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectProducer { + def rewriteAttrs2(attrMap: AttributeMap[Attribute]): LogicalPlan = { + // attributes rewritten in left / right children must only be rewritten + // in respective deserializer, group, attr, not both + // note: key deserializer refers to left side + val (leftAttrMap, rightAttrMap) = attrMap.partition(map => left.output.contains(map._2)) + val rewritten = super.rewriteAttrs(attrMap).asInstanceOf[CoGroup] + // revert attribute mappings to the wrong side + val revertLeftAttrMap = AttributeMap(leftAttrMap.map(m => (m._2, m._1))) + val revertRightAttrMap = AttributeMap(rightAttrMap.map(m => (m._2, m._1))) + + def revert(revertAttrMap: AttributeMap[Attribute])(attr: Attribute): Attribute = + revertAttrMap.get(attr).getOrElse(attr) + + CoGroup( + func, + rewritten.keyDeserializer, + rewritten.leftDeserializer, + rewritten.rightDeserializer, + rewritten.leftGroup.map(revert(revertRightAttrMap)), + rewritten.rightGroup.map(revert(revertLeftAttrMap)), + rewritten.leftAttr.map(revert(revertRightAttrMap)), + rewritten.rightAttr.map(revert(revertLeftAttrMap)), + rewritten.leftOrder, + rewritten.rightOrder, + rewritten.outputObjAttr, + rewritten.left, + rewritten.right + ) + } + override protected def withNewChildrenInternal( newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6a90b3409d3a..d937e2889590 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2936,6 +2936,56 @@ class DataFrameSuite extends QueryTest parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) } + test("groupBy.as: cogroup two grouped dataframes") { + val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") + .repartition($"a", $"b").sortWithinPartitions("a", "b").cache() + val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c") + .repartition($"a", $"b").sortWithinPartitions("a", "b").cache() + + val window = Window.partitionBy("b", "a") + + implicit val valueEncoder = RowEncoder(df1.schema) + + val df3 = df1.groupBy("a", "b").as[GroupByKey, Row] + .cogroup(df2.withColumn("c", sum($"c").over(window)).groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) => + data1.zip(data2).map { p => + p._1.getInt(2) + p._2.getInt(2) + } + }.toDF + df3.show() + checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) + } + + test("group.by: cogroup with same plan on both sides") { + val df = spark.range(3) + + val left_grouped_df = df.groupBy("id").as[Long, Long] + val right_grouped_df = df.groupBy("id").as[Long, Long] + + val cogroup_df = left_grouped_df.cogroup(right_grouped_df)( + (key: Long, left: Iterator[Long], right: Iterator[Long]) => left + ) + + val actual = 
cogroup_df.sort().collect() + assert(actual === Seq(0, 1, 2)) + } + + test("join") { + withSQLConf( + SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "WARN", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = spark.range(3).select($"id", ($"id" * 10).as("day")) // Seq(1L, 2L, 3L).toDF("id") works just fine + val left_df = df.withColumn("side", lit("left")) + val right_df = df.withColumn("side", lit("right")) + + val join_df = left_df.join(right_df, Seq("id", "day")) + //val join_df = left_df.join(right_df, left_df("id") <=> right_df("id") && left_df("day") <=> right_df("day")) + + join_df.explain() + assert(false) + } + } + test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 09da1e1e7b01..c8d6f316ef47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.execution.exchange import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ -import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{CoGroupExec, DummySparkPlan, SortExec, SparkPlan} import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec @@ -1160,6 +1161,41 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("CoGroup") { + val intDeserializer = Encoders.INT.asInstanceOf[ExpressionEncoder[Int]].objDeserializer + val left = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(Seq( + years(exprA), bucket(4, exprB), days(exprC)), 4) + ) + val right = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(Seq( + years(exprA), bucket(4, exprB), days(exprC)), 4) + ) + + val cogroupExec = CoGroupExec( + (key: Any, left: Iterator[Any], right: Iterator[Any]) => left, + intDeserializer, + intDeserializer, + intDeserializer, + Seq(AttributeReference("key", IntegerType)()), + Seq(AttributeReference("key", IntegerType)()), + Seq(AttributeReference("key", IntegerType)(), AttributeReference("value", IntegerType)()), + Seq(AttributeReference("key", IntegerType)(), AttributeReference("value", IntegerType)()), + AttributeReference("value", IntegerType)(), + left, + right) + + val result = EnsureRequirements.apply(cogroupExec) + result match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprC)) + assert(rightKeys === 
Seq(exprB, exprA)) + case other => fail(other.toString) + } + } + def bucket(numBuckets: Int, expr: Expression): TransformExpression = { TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) } From b0f5ea653f8b8f960afd3710eb2448f57a188690 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 20 Jan 2023 12:01:38 +0100 Subject: [PATCH 2/9] Remove unrelated tests, improve comments --- .../sql/catalyst/plans/logical/object.scala | 34 ++------------ .../org/apache/spark/sql/DataFrameSuite.scala | 44 ++----------------- .../exchange/EnsureRequirementsSuite.scala | 40 +---------------- 3 files changed, 10 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 121b4f8b514b..07f96e6e1aec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -677,6 +677,10 @@ object CoGroup { left: LogicalPlan, right: LogicalPlan): LogicalPlan = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) + + // SPARK-42132: The DeduplicateRelations rule would replace duplicate attributes + // in leftGroup and leftAttr as well, but not in rightDeserializer + // aliasing those attributes here deduplicates them as well val duplicateAttributes = right.output.filter(left.output.contains) .map(a => a -> Alias(a, a.name)()).toMap @@ -738,36 +742,6 @@ case class CoGroup( outputObjAttr: Attribute, left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectProducer { - def rewriteAttrs2(attrMap: AttributeMap[Attribute]): LogicalPlan = { - // attributes rewritten in left / right children must only be rewritten - // in respective deserializer, group, attr, not both - // note: key deserializer refers to left side - val (leftAttrMap, rightAttrMap) = attrMap.partition(map => left.output.contains(map._2)) - val rewritten = super.rewriteAttrs(attrMap).asInstanceOf[CoGroup] - // revert attribute mappings to the wrong side - val revertLeftAttrMap = AttributeMap(leftAttrMap.map(m => (m._2, m._1))) - val revertRightAttrMap = AttributeMap(rightAttrMap.map(m => (m._2, m._1))) - - def revert(revertAttrMap: AttributeMap[Attribute])(attr: Attribute): Attribute = - revertAttrMap.get(attr).getOrElse(attr) - - CoGroup( - func, - rewritten.keyDeserializer, - rewritten.leftDeserializer, - rewritten.rightDeserializer, - rewritten.leftGroup.map(revert(revertRightAttrMap)), - rewritten.rightGroup.map(revert(revertLeftAttrMap)), - rewritten.leftAttr.map(revert(revertRightAttrMap)), - rewritten.rightAttr.map(revert(revertLeftAttrMap)), - rewritten.leftOrder, - rewritten.rightOrder, - rewritten.outputObjAttr, - rewritten.left, - rewritten.right - ) - } - override protected def withNewChildrenInternal( newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d937e2889590..6f6aadd8ce04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2936,56 +2936,20 @@ class DataFrameSuite extends QueryTest parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) } - test("groupBy.as: cogroup two grouped 
dataframes") { - val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") - .repartition($"a", $"b").sortWithinPartitions("a", "b").cache() - val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c") - .repartition($"a", $"b").sortWithinPartitions("a", "b").cache() - - val window = Window.partitionBy("b", "a") - - implicit val valueEncoder = RowEncoder(df1.schema) - - val df3 = df1.groupBy("a", "b").as[GroupByKey, Row] - .cogroup(df2.withColumn("c", sum($"c").over(window)).groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) => - data1.zip(data2).map { p => - p._1.getInt(2) + p._2.getInt(2) - } - }.toDF - df3.show() - checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil) - } - - test("group.by: cogroup with same plan on both sides") { + test("SPARK-42132: group.by: cogroup with same plan on both sides") { val df = spark.range(3) val left_grouped_df = df.groupBy("id").as[Long, Long] val right_grouped_df = df.groupBy("id").as[Long, Long] - val cogroup_df = left_grouped_df.cogroup(right_grouped_df)( - (key: Long, left: Iterator[Long], right: Iterator[Long]) => left - ) + val cogroup_df = left_grouped_df.cogroup(right_grouped_df) { + case (key, left, right) => left + } val actual = cogroup_df.sort().collect() assert(actual === Seq(0, 1, 2)) } - test("join") { - withSQLConf( - SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "WARN", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = spark.range(3).select($"id", ($"id" * 10).as("day")) // Seq(1L, 2L, 3L).toDF("id") works just fine - val left_df = df.withColumn("side", lit("left")) - val right_df = df.withColumn("side", lit("right")) - - val join_df = left_df.join(right_df, Seq("id", "day")) - //val join_df = left_df.join(right_df, left_df("id") <=> right_df("id") && left_df("day") <=> right_df("day")) - - join_df.explain() - assert(false) - } - } - test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index c8d6f316ef47..09da1e1e7b01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.execution.exchange import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ -import org.apache.spark.sql.execution.{CoGroupExec, DummySparkPlan, SortExec, SparkPlan} +import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec @@ -1161,41 +1160,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("CoGroup") { - val 
intDeserializer = Encoders.INT.asInstanceOf[ExpressionEncoder[Int]].objDeserializer - val left = DummySparkPlan( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprA), bucket(4, exprB), days(exprC)), 4) - ) - val right = DummySparkPlan( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprA), bucket(4, exprB), days(exprC)), 4) - ) - - val cogroupExec = CoGroupExec( - (key: Any, left: Iterator[Any], right: Iterator[Any]) => left, - intDeserializer, - intDeserializer, - intDeserializer, - Seq(AttributeReference("key", IntegerType)()), - Seq(AttributeReference("key", IntegerType)()), - Seq(AttributeReference("key", IntegerType)(), AttributeReference("value", IntegerType)()), - Seq(AttributeReference("key", IntegerType)(), AttributeReference("value", IntegerType)()), - AttributeReference("value", IntegerType)(), - left, - right) - - val result = EnsureRequirements.apply(cogroupExec) - result match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), - SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => - assert(leftKeys === Seq(exprB, exprC)) - assert(rightKeys === Seq(exprB, exprA)) - case other => fail(other.toString) - } - } - def bucket(numBuckets: Int, expr: Expression): TransformExpression = { TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) } From eec3c98960d1d2c1844b17c4ddaeb82876044a23 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Jan 2023 08:56:00 +0100 Subject: [PATCH 3/9] Rewrite right sort order --- .../spark/sql/catalyst/plans/logical/object.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 07f96e6e1aec..aacc637a55f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -681,15 +681,12 @@ object CoGroup { // SPARK-42132: The DeduplicateRelations rule would replace duplicate attributes // in leftGroup and leftAttr as well, but not in rightDeserializer // aliasing those attributes here deduplicates them as well - val duplicateAttributes = right.output.filter(left.output.contains) - .map(a => a -> Alias(a, a.name)()).toMap + val duplicateAttributes = AttributeMap( + right.output.filter(left.output.contains).map(a => a -> Alias(a, a.name)()) + ) def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] = { - if (duplicateAttributes.nonEmpty) { - attrs.map(attr => duplicateAttributes.getOrElse(attr, attr)) - } else { - attrs - } + attrs.map(attr => duplicateAttributes.getOrElse(attr, attr)) } val (dedupRightGroup, dedupRightAttr, dedupRightOrder, dedupRight) = @@ -697,7 +694,9 @@ object CoGroup { ( dedup(rightGroup).map(_.toAttribute), dedup(rightAttr).map(_.toAttribute), - rightOrder, + rightOrder.map(_.transformDown { + case a: Attribute => duplicateAttributes.getOrElse(a, a) + }.asInstanceOf[SortOrder]), Project(dedup(right.output), right) ) } else { From 0bde873dc9bf955b908e0ddf10d45ea430b07c21 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Jan 2023 10:01:36 +0100 Subject: [PATCH 4/9] Add test for sorted cogroup on duplicate reference --- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6f6aadd8ce04..63b0c0543f3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2950,6 +2950,20 @@ class DataFrameSuite extends QueryTest assert(actual === Seq(0, 1, 2)) } + test("SPARK-42132: group.by: cogroupSorted with same plan on both sides") { + val df = spark.range(3) + + val left_grouped_df = df.groupBy("id").as[Long, Long] + val right_grouped_df = df.groupBy("id").as[Long, Long] + + val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id") { + case (key, left, right) => left + } + + val actual = cogroup_df.sort().collect() + assert(actual === Seq(0, 1, 2)) + } + test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") From 023c1993fc55a22c5e4435af6cc43c56d02745de Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Jan 2023 10:02:00 +0100 Subject: [PATCH 5/9] Remove right order from deduplication --- .../spark/sql/catalyst/plans/logical/object.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index aacc637a55f7..9d47f4f005c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -689,18 +689,16 @@ object CoGroup { attrs.map(attr => duplicateAttributes.getOrElse(attr, attr)) } - val (dedupRightGroup, dedupRightAttr, dedupRightOrder, dedupRight) = + // rightOrder is resolved against right plan, so deduplication not needed + val (dedupRightGroup, dedupRightAttr, dedupRight) = if (duplicateAttributes.nonEmpty) { ( dedup(rightGroup).map(_.toAttribute), dedup(rightAttr).map(_.toAttribute), - rightOrder.map(_.transformDown { - case a: Attribute => duplicateAttributes.getOrElse(a, a) - }.asInstanceOf[SortOrder]), Project(dedup(right.output), right) ) } else { - (rightGroup, rightAttr, rightOrder, right) + (rightGroup, rightAttr, right) } val cogrouped = CoGroup( @@ -715,7 +713,7 @@ object CoGroup { leftAttr, dedupRightAttr, leftOrder, - dedupRightOrder, + rightOrder, CatalystSerde.generateObjAttr[OUT], left, dedupRight) From 442482089dbc54846d7c5e31460eaf1e976ade4e Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Jan 2023 10:19:19 +0100 Subject: [PATCH 6/9] Finish test, rewording comments --- .../sql/catalyst/plans/logical/object.scala | 9 +++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 20 +++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 9d47f4f005c7..506237c9ea91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -679,15 +679,16 @@ object CoGroup { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) // SPARK-42132: The DeduplicateRelations rule 
would replace duplicate attributes - // in leftGroup and leftAttr as well, but not in rightDeserializer - // aliasing those attributes here deduplicates them as well + // in the right plan and rewrite rightGroup and rightAttr. But it would also rewrite + // leftGroup and leftAttr, which is wrong. Additionally, it does not rewrite rightDeserializer. + // Aliasing duplicate attributes in the right plan deduplicates them and stops + // DeduplicateRelations to do harm. val duplicateAttributes = AttributeMap( right.output.filter(left.output.contains).map(a => a -> Alias(a, a.name)()) ) - def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] = { + def dedup(attrs: Seq[Attribute]): Seq[NamedExpression] = attrs.map(attr => duplicateAttributes.getOrElse(attr, attr)) - } // rightOrder is resolved against right plan, so deduplication not needed val (dedupRightGroup, dedupRightAttr, dedupRight) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 63b0c0543f3a..c56c433d6c89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2943,25 +2943,29 @@ class DataFrameSuite extends QueryTest val right_grouped_df = df.groupBy("id").as[Long, Long] val cogroup_df = left_grouped_df.cogroup(right_grouped_df) { - case (key, left, right) => left + case (key, left, right) => left.zip(right) } val actual = cogroup_df.sort().collect() - assert(actual === Seq(0, 1, 2)) + assert(actual === Seq((0, 0), (1, 1), (2, 2))) } test("SPARK-42132: group.by: cogroupSorted with same plan on both sides") { - val df = spark.range(3) + val df = spark.range(3).join(spark.range(2).withColumnRenamed("id", "value")) - val left_grouped_df = df.groupBy("id").as[Long, Long] - val right_grouped_df = df.groupBy("id").as[Long, Long] + val left_grouped_df = df.groupBy("id").as[Long, (Long, Long)] + val right_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id") { - case (key, left, right) => left + val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"value")($"value".desc) { + case (key, left, right) => left.zip(right) } val actual = cogroup_df.sort().collect() - assert(actual === Seq(0, 1, 2)) + assert(actual === Seq( + ((0, 0), (0, 1)), ((0, 1), (0, 0)), + ((1, 0), (1, 1)), ((1, 1), (1, 0)), + ((2, 0), (2, 1)), ((2, 1), (2, 0)) + )) } test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { From c8dfd9d0dc5af95667d32bae0469b0613a3aa887 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Mon, 30 Jan 2023 08:40:32 +0100 Subject: [PATCH 7/9] Move tests into DatasetSuite --- .../org/apache/spark/sql/DataFrameSuite.scala | 32 ------------------- .../org/apache/spark/sql/DatasetSuite.scala | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c56c433d6c89..6a90b3409d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2936,38 +2936,6 @@ class DataFrameSuite extends QueryTest parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) } - test("SPARK-42132: group.by: cogroup with same plan on both sides") { - val df = spark.range(3) - - val 
left_grouped_df = df.groupBy("id").as[Long, Long] - val right_grouped_df = df.groupBy("id").as[Long, Long] - - val cogroup_df = left_grouped_df.cogroup(right_grouped_df) { - case (key, left, right) => left.zip(right) - } - - val actual = cogroup_df.sort().collect() - assert(actual === Seq((0, 0), (1, 1), (2, 2))) - } - - test("SPARK-42132: group.by: cogroupSorted with same plan on both sides") { - val df = spark.range(3).join(spark.range(2).withColumnRenamed("id", "value")) - - val left_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - val right_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - - val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"value")($"value".desc) { - case (key, left, right) => left.zip(right) - } - - val actual = cogroup_df.sort().collect() - assert(actual === Seq( - ((0, 0), (0, 1)), ((0, 1), (0, 0)), - ((1, 0), (1, 1)), ((1, 1), (1, 0)), - ((2, 0), (2, 1)), ((2, 1), (2, 0)) - )) - } - test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0766dd2e7726..823d0805201e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -916,6 +916,38 @@ class DatasetSuite extends QueryTest } } + test("SPARK-42132: cogroup with same plan on both sides") { + val df = spark.range(3) + + val left_grouped_df = df.groupBy("id").as[Long, Long] + val right_grouped_df = df.groupBy("id").as[Long, Long] + + val cogroup_df = left_grouped_df.cogroup(right_grouped_df) { + case (key, left, right) => left.zip(right) + } + + val actual = cogroup_df.sort().collect() + assert(actual === Seq((0, 0), (1, 1), (2, 2))) + } + + test("SPARK-42132: cogroup with sorted with same plan on both sides") { + val df = spark.range(3).join(spark.range(2).withColumnRenamed("id", "value")) + + val left_grouped_df = df.groupBy("id").as[Long, (Long, Long)] + val right_grouped_df = df.groupBy("id").as[Long, (Long, Long)] + + val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"value")($"value".desc) { + case (key, left, right) => left.zip(right) + } + + val actual = cogroup_df.sort().collect() + assert(actual === Seq( + ((0, 0), (0, 1)), ((0, 1), (0, 0)), + ((1, 0), (1, 1)), ((1, 1), (1, 0)), + ((2, 0), (2, 1)), ((2, 1), (2, 0)) + )) + } + test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation() From 25fa7977e082b8ee66d32fca125813f6c4e70609 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 26 Jan 2023 17:10:15 +0100 Subject: [PATCH 8/9] Adjust tests for AppendColumns fix --- .../org/apache/spark/sql/DatasetSuite.scala | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 823d0805201e..7057559b4eb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -930,13 +930,31 @@ class DatasetSuite extends QueryTest assert(actual === Seq((0, 0), (1, 1), (2, 2))) } - test("SPARK-42132: cogroup with sorted with same plan on both sides") { - val df = 
spark.range(3).join(spark.range(2).withColumnRenamed("id", "value")) + test("SPARK-42132: cogroup with sorted and same plan on both sides") { + val df = spark.range(3).join(spark.range(2)).as[(Long, Long)] val left_grouped_df = df.groupBy("id").as[Long, (Long, Long)] val right_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"value")($"value".desc) { + val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id".desc) { + case (key, left, right) => left.zip(right) + } + + val actual = cogroup_df.sort().collect() + assert(actual === Seq( + ((0, 0), (0, 1)), ((0, 1), (0, 0)), + ((1, 0), (1, 1)), ((1, 1), (1, 0)), + ((2, 0), (2, 1)), ((2, 1), (2, 0)) + )) + } + + test("SPARK-42132: cogroup groupby function with sorted and same plan on both sides") { + val df = spark.range(3).join(spark.range(2)).as[(Long, Long)] + + val left_grouped_df = df.groupByKey(_._1) + val right_grouped_df = df.groupByKey(_._1) + + val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id".desc) { case (key, left, right) => left.zip(right) } From cc5773cdd2f91c05ca5129a627f8ffed92cb845f Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 26 Jan 2023 17:15:54 +0100 Subject: [PATCH 9/9] Remove cogroupSorted tests --- .../org/apache/spark/sql/DatasetSuite.scala | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7057559b4eb2..b3db996a9c3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -930,42 +930,6 @@ class DatasetSuite extends QueryTest assert(actual === Seq((0, 0), (1, 1), (2, 2))) } - test("SPARK-42132: cogroup with sorted and same plan on both sides") { - val df = spark.range(3).join(spark.range(2)).as[(Long, Long)] - - val left_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - val right_grouped_df = df.groupBy("id").as[Long, (Long, Long)] - - val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id".desc) { - case (key, left, right) => left.zip(right) - } - - val actual = cogroup_df.sort().collect() - assert(actual === Seq( - ((0, 0), (0, 1)), ((0, 1), (0, 0)), - ((1, 0), (1, 1)), ((1, 1), (1, 0)), - ((2, 0), (2, 1)), ((2, 1), (2, 0)) - )) - } - - test("SPARK-42132: cogroup groupby function with sorted and same plan on both sides") { - val df = spark.range(3).join(spark.range(2)).as[(Long, Long)] - - val left_grouped_df = df.groupByKey(_._1) - val right_grouped_df = df.groupByKey(_._1) - - val cogroup_df = left_grouped_df.cogroupSorted(right_grouped_df)($"id")($"id".desc) { - case (key, left, right) => left.zip(right) - } - - val actual = cogroup_df.sort().collect() - assert(actual === Seq( - ((0, 0), (0, 1)), ((0, 1), (0, 0)), - ((1, 0), (1, 1)), ((1, 1), (1, 0)), - ((2, 0), (2, 1)), ((2, 1), (2, 0)) - )) - } - test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation()