From d7c1f657e78a61a95333687d7b76fc392b6281e0 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Tue, 5 Oct 2021 03:45:44 +0900
Subject: [PATCH] Fix DeduplicateRelations to copy the dataset_id tag to avoid
 ambiguous self join.

---
 .../analysis/DeduplicateRelations.scala       |  50 +++++--
 .../spark/sql/DataFrameSelfJoinSuite.scala    | 127 +++++++++++++++++-
 2 files changed, 161 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 5dfed394f31e5..4ff1837ddc215 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -181,13 +181,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion: SerializeFromObject
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(
-          serializer = oldVersion.serializer.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // Handle projects that create conflicting aliases.
       case oldVersion @ Project(projectList, _)
           if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
+        val newVersion = oldVersion.copy(projectList = newAliases(projectList))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // Handle projects that create conflicting outer references.
       case oldVersion @ Project(projectList, _)
@@ -197,7 +200,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           case o @ OuterReference(a) if conflictingAttributes.contains(a) => Alias(o, a.name)()
           case other => other
         }
-        Seq((oldVersion, oldVersion.copy(projectList = aliasedProjectList)))
+        val newVersion = oldVersion.copy(projectList = aliasedProjectList)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // We don't need to search child plan recursively if the projectList of a Project
       // is only composed of Alias and doesn't contain any conflicting attributes.
@@ -209,8 +214,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion @ Aggregate(_, aggregateExpressions, _)
           if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(
-          aggregateExpressions = newAliases(aggregateExpressions))))
+        val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // We don't search the child plan recursively for the same reason as the above Project.
       case _ @ Aggregate(_, aggregateExpressions, _)
@@ -219,24 +225,34 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ MapInPandas(_, output, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
+        val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion: Generate
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
-        Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
+        val newVersion = oldVersion.copy(generatorOutput = newOutput)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion: Expand
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
@@ -248,16 +264,22 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
             attr
           }
         }
-        Seq((oldVersion, oldVersion.copy(output = newOutput)))
+        val newVersion = oldVersion.copy(output = newOutput)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ Window(windowExpressions, _, _, child)
           if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
             .nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
+        val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ ScriptTransformation(_, output, _, _)
           if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case _ => plan.children.flatMap(collectConflictPlans)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 062404f412bb7..a0ddabcf76043 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,12 +17,15 @@
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.functions.{count, explode, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 
 class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -344,4 +347,124 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
     }
   }
+
+  test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
+    "to avoid ambiguous self join") {
+    // Test for Project
+    val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
+    val df2 = df1.filter($"value" === "A2")
+    assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
+    assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))
+
+    // Test for SerializeFromObject
+    val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
+    val df4 = df3.filter($"_1" <=> 0)
+    assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
+    assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))
+
+    // Test for Aggregate
+    val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
+    val df6 = df5.filter($"key1" > 0)
+    assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
+    assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+
+    // Test for MapInPandas
+    val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+      true)
+    val df7 = df1.mapInPandas(mapInPandasUDF)
+    val df8 = df7.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
+    assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+
+    // Test for FlatMapGroupsInPandas
+    val flatMapGroupsInPandasUDF = PythonUDF("flatMapGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true)
+    val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+    val df10 = df9.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
+    assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+
+    // Test for FlatMapCoGroupsInPandas
+    val flatMapCoGroupsInPandasUDF = PythonUDF("flatMapCoGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+    val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+      df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+    val df12 = df11.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
+    assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+
+    // Test for AttachDistributedSequence
+    val df13 = df1.withSequenceColumn("seq")
+    val df14 = df13.filter($"value" === "A2")
+    assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
df13("key1") === df14("key2"))) + assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) + + // Test for Generate + // Ensure that the root of the plan is Generate + val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList")) + .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF + val df16 = df15.filter($"a" > 0) + assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col"))) + assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col"))) + + // Test for Expand + // Ensure that the root of the plan is Expand + val df17 = + Expand( + Seq(Seq($"key1".expr, $"key2".expr)), + Seq( + AttributeReference("x", IntegerType)(), + AttributeReference("y", IntegerType)()), + df1.queryExecution.logical).toDF + val df18 = df17.filter($"x" > 0) + assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y"))) + assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y"))) + + // Test for Window + val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b") + // Ensure that the root of the plan is Window + val df19 = WindowPlan( + Seq(Alias(dfWithTS("time").expr, "ts")()), + Seq(dfWithTS("a").expr), + Seq(SortOrder(dfWithTS("a").expr, Ascending)), + dfWithTS.queryExecution.logical).toDF + val df20 = df19.filter($"a" > 0) + assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b"))) + assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b"))) + + // Test for ScriptTransformation + val ioSchema = + ScriptInputOutputSchema( + Seq(("TOK_TABLEROWFORMATFIELD", ","), + ("TOK_TABLEROWFORMATCOLLITEMS", "#"), + ("TOK_TABLEROWFORMATMAPKEYS", "@"), + ("TOK_TABLEROWFORMATNULL", "null"), + ("TOK_TABLEROWFORMATLINES", "\n")), + Seq(("TOK_TABLEROWFORMATFIELD", ","), + ("TOK_TABLEROWFORMATCOLLITEMS", "#"), + ("TOK_TABLEROWFORMATMAPKEYS", "@"), + ("TOK_TABLEROWFORMATNULL", "null"), + ("TOK_TABLEROWFORMATLINES", "\n")), None, None, + List.empty, List.empty, None, None, false) + // Ensure that the root of the plan is ScriptTransformation + val df21 = ScriptTransformation( + "cat", + Seq( + AttributeReference("x", IntegerType)(), + AttributeReference("y", IntegerType)()), + df1.queryExecution.logical, + ioSchema).toDF + val df22 = df21.filter($"x" > 0) + assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y"))) + assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y"))) + } }