From cb831d8a854986a535b93af70268c40c80a9bcec Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 25 Jun 2024 14:44:12 -0700 Subject: [PATCH 1/2] fix --- .../analysis/DeduplicateRelations.scala | 18 +++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 20 +++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 3e4344f98bce..0fa11b9c4503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -255,12 +255,18 @@ object DeduplicateRelations extends Rule[LogicalPlan] { val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap) val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap) val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap) - val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer] - .copy(inputAttributes = newLeftGroup) - val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer] - .copy(inputAttributes = newLeftAttr) - val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer] - .copy(inputAttributes = newRightAttr) + val newKeyDes = c.keyDeserializer match { + case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftGroup) + case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap)) + } + val newLeftDes = c.leftDeserializer match { + case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftAttr) + case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap)) + } + val newRightDes = c.rightDeserializer match { + case u: UnresolvedDeserializer => u.copy(inputAttributes = newRightAttr) + case e: Expression => e.withNewChildren(rewriteAttrs(e.children, rightAttrMap)) + } c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes, rightDeserializer = newRightDes, leftGroup = newLeftGroup, rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b939ed40c7db..c6d51a1661b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -952,6 +953,25 @@ class DatasetSuite extends QueryTest assert(result2.length == 3) } + test("correctly handle when deserializer in cogroup is resolved in dedup relation") { + val lhs = spark.createDataFrame( + List(Row(123L)).asJava, + StructType(Seq(StructField("GROUPING_KEY", LongType))) + ) + val rhs = spark.createDataFrame( + List(Row(0L, 123L)).asJava, + StructType(Seq(StructField("ID", LongType), StructField("GROUPING_KEY", LongType))) + ) + + val lhsKV = lhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY")) + val rhsKV = rhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY")) + val cogrouped = lhsKV.cogroup(rhsKV)( + (a: Long, b: Iterator[Row], c: Iterator[Row]) => Iterator(0L) + ) + val joined = rhs.join(cogrouped, col("ID") === col("value"), "left") + checkAnswer(joined, Row(0L, 123L, 0L) :: Nil) + } + test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation() From 8c783579e2fb4bbe72029b47b12818f9fbe404a9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 26 Jun 2024 09:41:44 +0800 Subject: [PATCH 2/2] Update DatasetSuite.scala --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c6d51a1661b9..fdb2ec30fdd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -953,7 +953,7 @@ class DatasetSuite extends QueryTest assert(result2.length == 3) } - test("correctly handle when deserializer in cogroup is resolved in dedup relation") { + test("SPARK-48718: cogroup deserializer expr is resolved before dedup relation") { val lhs = spark.createDataFrame( List(Row(123L)).asJava, StructType(Seq(StructField("GROUPING_KEY", LongType)))