DeduplicateRelations.scala
@@ -181,13 +181,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion: SerializeFromObject
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(
-   serializer = oldVersion.serializer.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
+ val newVersion = oldVersion.copy(projectList = newAliases(projectList))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

// Handle projects that create conflicting outer references.
case oldVersion @ Project(projectList, _)
@@ -197,7 +200,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
case o @ OuterReference(a) if conflictingAttributes.contains(a) => Alias(o, a.name)()
case other => other
}
- Seq((oldVersion, oldVersion.copy(projectList = aliasedProjectList)))
+ val newVersion = oldVersion.copy(projectList = aliasedProjectList)
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

// We don't need to search child plan recursively if the projectList of a Project
// is only composed of Alias and doesn't contain any conflicting attributes.
@@ -209,8 +214,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(
-   aggregateExpressions = newAliases(aggregateExpressions))))
+ val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

// We don't search the child plan recursively for the same reason as the above Project.
case _ @ Aggregate(_, aggregateExpressions, _)
@@ -219,24 +225,34 @@ object DeduplicateRelations extends Rule[LogicalPlan] {

case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion @ MapInPandas(_, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
+ val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
- Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
+ val newVersion = oldVersion.copy(generatorOutput = newOutput)
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion: Expand
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
@@ -248,16 +264,22 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
attr
}
}
- Seq((oldVersion, oldVersion.copy(output = newOutput)))
+ val newVersion = oldVersion.copy(output = newOutput)
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
- Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
+ val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case oldVersion @ ScriptTransformation(_, output, _, _)
if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))

case _ => plan.children.flatMap(collectConflictPlans)
}
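
Every rewrite above follows the same three-step shape: copy the node, copy the tags, return the (old, new) pair. A minimal sketch of that contract, assuming Catalyst's TreeNode.copyTagsFrom semantics (the helper name rewritePreservingTags is hypothetical, not part of this patch):

  // copy() yields a node with an empty tag map, so metadata such as
  // Dataset.DATASET_ID_TAG (consulted by ambiguous self-join detection)
  // is lost unless it is transferred explicitly.
  private def rewritePreservingTags[T <: LogicalPlan](
      oldVersion: T)(rewrite: T => T): (LogicalPlan, LogicalPlan) = {
    val newVersion = rewrite(oldVersion)
    newVersion.copyTagsFrom(oldVersion) // keep dataset_id and other tags
    (oldVersion, newVersion)
  }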
DataFrameSelfJoinSuite.scala
@@ -17,12 +17,15 @@

package org.apache.spark.sql

- import org.apache.spark.sql.catalyst.expressions.AttributeReference
+ import org.apache.spark.api.python.PythonEvalType
+ import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
+ import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
- import org.apache.spark.sql.functions.{count, sum}
+ import org.apache.spark.sql.functions.{count, explode, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
+ import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -344,4 +347,124 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
}
}

test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
"to avoid ambiguous self join") {
// Test for Project
val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
val df2 = df1.filter($"value" === "A2")
assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))

// Test for SerializeFromObject
val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
val df4 = df3.filter($"_1" <=> 0)
assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))

+ // Test for Aggregate
+ val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
+ val df6 = df5.filter($"key1" > 0)
+ assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
+ assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+
+ // Test for MapInPandas
+ val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+ true)
+ val df7 = df1.mapInPandas(mapInPandasUDF)
+ val df8 = df7.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
+ assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+
+ // Test for FlatMapGroupsInPandas
val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ true)
+ val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+ val df10 = df9.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
+ assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+
+ // Test for FlatMapCoGroupsInPandas
val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+ true)
+ val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+ df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+ val df12 = df11.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
+ assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+
+ // Test for AttachDistributedSequence
+ val df13 = df1.withSequenceColumn("seq")
+ val df14 = df13.filter($"value" === "A2")
+ assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
+ assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
+
+ // Test for Generate
+ // Ensure that the root of the plan is Generate
+ val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
+ .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF
+ val df16 = df15.filter($"a" > 0)
+ assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col")))
+ assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col")))
+
+ // Test for Expand
+ // Ensure that the root of the plan is Expand
+ val df17 =
+ Expand(
+ Seq(Seq($"key1".expr, $"key2".expr)),
+ Seq(
+ AttributeReference("x", IntegerType)(),
+ AttributeReference("y", IntegerType)()),
+ df1.queryExecution.logical).toDF
+ val df18 = df17.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y")))
+ assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y")))
+
+ // Test for Window
+ val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
+ // Ensure that the root of the plan is Window
+ val df19 = WindowPlan(
+ Seq(Alias(dfWithTS("time").expr, "ts")()),
+ Seq(dfWithTS("a").expr),
+ Seq(SortOrder(dfWithTS("a").expr, Ascending)),
+ dfWithTS.queryExecution.logical).toDF
+ val df20 = df19.filter($"a" > 0)
+ assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b")))
+ assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b")))
+
+ // Test for ScriptTransformation
+ val ioSchema =
+ ScriptInputOutputSchema(
+ Seq(("TOK_TABLEROWFORMATFIELD", ","),
+ ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+ ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+ ("TOK_TABLEROWFORMATNULL", "null"),
+ ("TOK_TABLEROWFORMATLINES", "\n")),
+ Seq(("TOK_TABLEROWFORMATFIELD", ","),
+ ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+ ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+ ("TOK_TABLEROWFORMATNULL", "null"),
+ ("TOK_TABLEROWFORMATLINES", "\n")), None, None,
+ List.empty, List.empty, None, None, false)
+ // Ensure that the root of the plan is ScriptTransformation
+ val df21 = ScriptTransformation(
+ "cat",
+ Seq(
+ AttributeReference("x", IntegerType)(),
+ AttributeReference("y", IntegerType)()),
+ df1.queryExecution.logical,
+ ioSchema).toDF
+ val df22 = df21.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y")))
+ assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y")))
+ }
}
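
What each assertAmbiguousSelfJoin call above boils down to, as a minimal end-to-end sketch (assuming the default spark.sql.analyzer.failAmbiguousSelfJoin=true): without copyTagsFrom, the plan rewritten by DeduplicateRelations loses its dataset_id tag, the ambiguity goes undetected, and no exception is thrown.

  val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
  val df2 = df1.filter($"value" === "A2") // self join forces deduplication of df1's plan
  // With the fix, the tag survives the rewrite and the ambiguous
  // column references are detected at analysis time:
  intercept[AnalysisException] {
    df1.join(df2, df1("key1") === df2("key2")).collect()
  }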