From f3dd17a5564ce5098aab112dda28fa02cd802af6 Mon Sep 17 00:00:00 2001 From: ashahid Date: Thu, 29 Feb 2024 15:20:02 -0800 Subject: [PATCH 01/10] SPARK-47217. bug fix for exception thrown in reused dataframes involving joins once the plan is de-duplicated. The fix involves using Dataset ID associated with the plans & attributes to attempt correct resolution --- .../analysis/ColumnResolutionHelper.scala | 42 ++++++++++ .../sql/catalyst/analysis/unresolved.scala | 41 ++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 37 ++++++++- .../apache/spark/sql/JavaDataFrameSuite.java | 76 +++++++++++++++++++ .../spark/sql/DataFrameAsOfJoinSuite.scala | 20 +++++ .../spark/sql/DataFrameSelfJoinSuite.scala | 10 +++ 7 files changed, 226 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 8ea50e2ceb65..22361b0e73c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -134,6 +134,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { expr: Expression, resolveColumnByName: Seq[String] => Option[Expression], getAttrCandidates: () => Seq[Attribute], + resolveOnDatasetId: (Long, String) => Option[NamedExpression], throws: Boolean, includeLastResort: Boolean): Expression = { def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { @@ -156,6 +157,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } matched(ordinal) + case u @ UnresolvedAttributeWithTag(attr, id) => + resolveOnDatasetId(id, attr.name).getOrElse(attr) + case u @ UnresolvedAttribute(nameParts) => val result = withPosition(u) { resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { @@ -452,6 +456,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { plan.resolve(nameParts, conf.resolver) }, getAttrCandidates = () => plan.output, + resolveOnDatasetId = (_, _) => None, throws = throws, includeLastResort = includeLastResort) } @@ -477,6 +482,43 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { assert(q.children.length == 1) q.children.head.output }, + + resolveOnDatasetId = (datasetid: Long, name: String) => { + def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[LogicalPlan] = { + var currentLp = lp + while(currentLp.children.size < 2) { + if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) { + return Option(currentLp) + } else { + if (currentLp.children.size == 1) { + currentLp = currentLp.children.head + } else { + // leaf node + return None + } + } + } + None + } + + val binaryNodeOpt = q.collectFirst { + case bn: BinaryNode => bn + } + + val resolveOnAttribs = binaryNodeOpt match { + case Some(bn) => + val leftDefOpt = findUnaryNodeMatchingTagId(bn.left) + val rightDefOpt = findUnaryNodeMatchingTagId(bn.right) + (leftDefOpt, rightDefOpt) match { + case (None, Some(lp)) => lp.output + case (Some(lp), None) => lp.output + case _ => q.children.head.output + } + + case _ => q.children.head.output + } + AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(name), conf.resolver) + }, throws = true, includeLastResort = includeLastResort) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7a3cc4bc8e83..397351e0c1fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -268,6 +268,47 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un } } +case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) extends Attribute with + Unevaluable { + def name: String = attribute.name + + override def exprId: ExprId = throw new UnresolvedException("exprId") + + override def dataType: DataType = throw new UnresolvedException("dataType") + + override def nullable: Boolean = throw new UnresolvedException("nullable") + + override def qualifier: Seq[String] = throw new UnresolvedException("qualifier") + + override lazy val resolved = false + + override def newInstance(): UnresolvedAttributeWithTag = this + + override def withNullability(newNullability: Boolean): UnresolvedAttributeWithTag = this + + override def withQualifier(newQualifier: Seq[String]): UnresolvedAttributeWithTag = this + + override def withName(newName: String): UnresolvedAttributeWithTag = this + + override def withMetadata(newMetadata: Metadata): Attribute = this + + override def withExprId(newExprId: ExprId): UnresolvedAttributeWithTag = this + + override def withDataType(newType: DataType): Attribute = this + + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE) + + override def toString: String = s"'$name" + + override def sql: String = attribute.sql + + /** + * Returns true if this matches the token. This requires the attribute to only have one part in + * its name and that matches the given token in a case insensitive way. + */ + def equalsIgnoreCase(token: String): Boolean = token.equalsIgnoreCase(attribute.name) +} + object UnresolvedAttribute extends AttributeNameParser { /** * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.'). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e1121d1f9026..a9b130c981ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable + import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -30,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{MapType, StructType} - abstract class LogicalPlan extends QueryPlan[LogicalPlan] with AnalysisHelper @@ -199,6 +200,7 @@ object LogicalPlan { // to the old code path. 
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id") private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col") + private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 189be1d6a30d..b767cc01f341 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -73,7 +73,7 @@ private[sql] object Dataset { val curId = new java.util.concurrent.atomic.AtomicLong() val DATASET_ID_KEY = "__dataset_id" val COL_POS_KEY = "__col_position" - val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id") + val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) @@ -1308,12 +1308,20 @@ class Dataset[T] private[sql]( case a: AttributeReference if logicalPlan.outputSet.contains(a) => val index = logicalPlan.output.indexWhere(_.exprId == a.exprId) joined.left.output(index) + + case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) => + UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY)) } + val rightAsOfExpr = rightAsOf.expr.transformUp { case a: AttributeReference if other.logicalPlan.outputSet.contains(a) => val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId) joined.right.output(index) + + case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) => + UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY)) } + withPlan { AsOfJoin( joined.left, joined.right, @@ -1576,7 +1584,30 @@ class Dataset[T] private[sql]( case other => other } - Project(untypedCols.map(_.named), logicalPlan) + val namedExprs = untypedCols.map(_.named) + val inputSet = logicalPlan.outputSet + val rectifiedNamedExprs = namedExprs.map(ne => ne match { + + case al: Alias if !al.references.subsetOf(inputSet) && + al.nonInheritableMetadataKeys.contains(Dataset.DATASET_ID_KEY) => + val unresolvedExpr = al.child.transformUp { + case attr: AttributeReference if !inputSet.contains(attr) => + UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY)) + } + val newAl = al.copy(child = unresolvedExpr, name = al.name)(exprId = al.exprId, + qualifier = al.qualifier, explicitMetadata = al.explicitMetadata, + nonInheritableMetadataKeys = al.nonInheritableMetadataKeys) + newAl.copyTagsFrom(al) + newAl + + case attr: Attribute if !inputSet.contains(attr) && + attr.metadata.contains(Dataset.DATASET_ID_KEY) => + UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY)) + + case _ => ne + + }) + Project(rectifiedNamedExprs, logicalPlan) } /** diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 26a19cbed1b9..f325e3c1a80e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -24,6 +24,7 @@ import java.math.BigInteger; import java.math.BigDecimal; +import org.apache.spark.sql.catalyst.plans.logical.Join; import scala.collection.Seq; import scala.jdk.javaapi.CollectionConverters; @@ -31,6 +32,11 @@ import com.google.common.primitives.Ints; import org.junit.jupiter.api.*; +import org.apache.spark.sql.catalyst.expressions.Alias; +import org.apache.spark.sql.catalyst.expressions.AttributeReference; +import org.apache.spark.sql.catalyst.expressions.AttributeSet; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Column; @@ -540,4 +546,74 @@ public void testUDF() { .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assertions.assertArrayEquals(expected, result); } + + @Test + public void testDedupBehaviourOnProjection_SPARK_47217() { + //Create table1 DF + List table1Data = Arrays.asList( + RowFactory.create(1, 2, 3), + RowFactory.create(1, 2, 3)); + StructType table1Scema = new StructType() + .add("col11", DataTypes.IntegerType) + .add("col12", DataTypes.IntegerType) + .add("col13", DataTypes.IntegerType); + + Dataset table1 = spark.createDataFrame(table1Data, table1Scema); + + //Create table2 DataFrame + List table2Data = Arrays.asList(RowFactory.create(1, 2, 3)); + StructType table2Schema = new StructType() + .add("col21", DataTypes.IntegerType) + .add("col22", DataTypes.IntegerType) + .add("col23", DataTypes.IntegerType); + + Dataset table2 = spark.createDataFrame(table2Data, table2Schema); + + //Create table 3 DataFrame + List table3Data = Arrays.asList(RowFactory.create(1, 2, 3), RowFactory.create(1, 2, 3)); + StructType table3Schema = new StructType(). + add("col31", DataTypes.IntegerType). + add("col32", DataTypes.IntegerType). + add("col33", DataTypes.IntegerType); + + Dataset table3 = spark.createDataFrame(table3Data, table3Schema); + + //Perform left outer join for table2 + Dataset srcDf = table1.join( + table2, + table1.col("col11").equalTo(table2.col("col21")), + "left_outer").select( + table1.col("col11"), + table1.col("col12"), + table1.col("col13"), + table2.col("col22")); + + //Perform leftouter join for exchange table2(firstjoin) + srcDf = srcDf.join( + broadcast(table3), + srcDf.col("col12").equalTo( + table3.col("col32")), "left_outer"). 
+ select(srcDf.col("col11"), srcDf.col("col12"), + srcDf.col("col13"), + table3.col("col33").as("col33_1")); + + //Perform left outer joinfor exchangeRateTable1 again(secondjoin) + Dataset temp = srcDf.join(broadcast(table3), + srcDf.col("col11").equalTo( + table3.col("col31")), "left_outer"); + + srcDf = temp.select(srcDf.col("col11"), srcDf.col("col12"), + srcDf.col("col13"), srcDf.col("col33_1"), + table3.col("col33").as("col33_2")); + + // verify optimized plan creation ok + srcDf.queryExecution().optimizedPlan(); + LogicalPlan lp = srcDf.queryExecution().analyzed(); + // verify attribute ref resolution is correct, i.e it resolves to the right leg of join + AttributeReference refToCheck = + (AttributeReference) ((Alias)((Project)lp).projectList().last()).child(); + AttributeSet compareSet = + ((Join) ((Project) lp).child()).right().outputSet(); + Assertions.assertTrue(compareSet.contains(refToCheck)); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala index 280eb095dc75..3db8d374c4bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.catalyst.plans.logical.AsOfJoin import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession @@ -173,4 +174,23 @@ class DataFrameAsOfJoinSuite extends QueryTest ) ) } + + test("SPARK_47217: Dedup of relations can impact projected columns resolution -1") { + val (df1, df2) = prepareForAsOfJoin() + val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"), + df2.col("b"), df1.col("a").as("aa")) + + // In stock spark this would throw ambiguous column exception, even though it is not ambiguous + val asOfjoin2 = join1.joinAsOf( + df1, df1.col("a"), join1.col("a"), usingColumns = Seq.empty, + joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest") + + asOfjoin2.queryExecution.assertAnalyzed() + + val testDf = asOfjoin2.select(df1.col("a")) + val analyzed = testDf.queryExecution.analyzed + val attributeRefToCheck = analyzed.output.head + assert(analyzed.children(0).asInstanceOf[AsOfJoin].right.outputSet. + contains(attributeRefToCheck)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index c777d2207584..83d289fcd207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -488,4 +488,14 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } } + + test("SPARK_47217: Dedup of relations can impact projected columns resolution") { + val df = Seq((1, 2)).toDF("a", "b") + val df2 = df.select(df("a").as("aa"), df("b").as("bb")) + val df3 = df2.join(df, df2("bb") === df("b")).select(df2("aa"), df("a")) + + checkAnswer( + df3, + Row(1, 1) :: Nil) + } } From c29366f489a934eb6aff8b029244fda5cc77daef Mon Sep 17 00:00:00 2001 From: ashahid Date: Thu, 29 Feb 2024 16:52:22 -0800 Subject: [PATCH 02/10] SPARK-47217. 
fix test failures --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b767cc01f341..7a3eae7144b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1591,7 +1591,8 @@ class Dataset[T] private[sql]( case al: Alias if !al.references.subsetOf(inputSet) && al.nonInheritableMetadataKeys.contains(Dataset.DATASET_ID_KEY) => val unresolvedExpr = al.child.transformUp { - case attr: AttributeReference if !inputSet.contains(attr) => + case attr: AttributeReference if !inputSet.contains(attr) && + attr.metadata.contains(Dataset.DATASET_ID_KEY) => UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY)) } val newAl = al.copy(child = unresolvedExpr, name = al.name)(exprId = al.exprId, From 31d66c2540423aa852d444c14e88cbfcfaaedc8c Mon Sep 17 00:00:00 2001 From: ashahid Date: Thu, 29 Feb 2024 19:11:17 -0800 Subject: [PATCH 03/10] SPARK-47217. fix style format issue --- .../apache/spark/sql/JavaDataFrameSuite.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index f325e3c1a80e..09ccbd508e45 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -571,10 +571,10 @@ public void testDedupBehaviourOnProjection_SPARK_47217() { //Create table 3 DataFrame List table3Data = Arrays.asList(RowFactory.create(1, 2, 3), RowFactory.create(1, 2, 3)); - StructType table3Schema = new StructType(). - add("col31", DataTypes.IntegerType). - add("col32", DataTypes.IntegerType). - add("col33", DataTypes.IntegerType); + StructType table3Schema = new StructType() + .add("col31", DataTypes.IntegerType) + .add("col32", DataTypes.IntegerType) + .add("col33", DataTypes.IntegerType); Dataset table3 = spark.createDataFrame(table3Data, table3Schema); @@ -591,16 +591,15 @@ public void testDedupBehaviourOnProjection_SPARK_47217() { //Perform leftouter join for exchange table2(firstjoin) srcDf = srcDf.join( broadcast(table3), - srcDf.col("col12").equalTo( - table3.col("col32")), "left_outer"). 
- select(srcDf.col("col11"), srcDf.col("col12"), + srcDf.col("col12").equalTo(table3.col("col32")), + "left_outer").select(srcDf.col("col11"), + srcDf.col("col12"), srcDf.col("col13"), table3.col("col33").as("col33_1")); //Perform left outer joinfor exchangeRateTable1 again(secondjoin) Dataset temp = srcDf.join(broadcast(table3), - srcDf.col("col11").equalTo( - table3.col("col31")), "left_outer"); + srcDf.col("col11").equalTo(table3.col("col31")), "left_outer"); srcDf = temp.select(srcDf.col("col11"), srcDf.col("col12"), srcDf.col("col13"), srcDf.col("col33_1"), From 127016c1ad41b9121639b678c9de9080a0ce9d01 Mon Sep 17 00:00:00 2001 From: ashahid Date: Tue, 5 Mar 2024 21:29:33 -0800 Subject: [PATCH 04/10] SPARK-47217 : Fixing tests and code to try and resolve ambiguity in self join conditions --- .../scala/org/apache/spark/sql/Dataset.scala | 44 ++++++++- .../apache/spark/sql/JavaDataFrameSuite.java | 69 ------------- .../spark/sql/DataFrameAsOfJoinSuite.scala | 2 +- .../spark/sql/DataFrameSelfJoinSuite.scala | 98 ++++++++++++------- 4 files changed, 105 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7a3eae7144b3..d79c2c499bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1150,10 +1150,18 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. - val plan = withPlan( - Join(logicalPlan, right.logicalPlan, - JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE)) - .queryExecution.analyzed.asInstanceOf[Join] + + val plan = try { + withPlan( + Join(logicalPlan, right.logicalPlan, + JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE)) + .queryExecution.analyzed.asInstanceOf[Join] + } catch { + case ae: AnalysisException if ae.message.contains("ambiguous") => + // attempt to resolve ambiguity + tryAmbiguityResolution(right, joinExprs, joinType) + } + // If auto self join alias is disabled, return the plan. 
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) { @@ -1174,6 +1182,34 @@ class Dataset[T] private[sql]( JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan) } + private def tryAmbiguityResolution( + right: Dataset[_], + joinExprs: Option[Column], + joinType: String) = { + val planPart1 = withPlan( + Join(logicalPlan, right.logicalPlan, + JoinType(joinType), None, JoinHint.NONE)) + .queryExecution.analyzed.asInstanceOf[Join] + val inputSet = planPart1.outputSet + val joinExprsRectified = joinExprs.map(_.expr transformUp { + case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) => + val attribTagId = attr.metadata.getLong(Dataset.DATASET_ID_KEY) + val leftTagIdMap = planPart1.left.getTagValue(LogicalPlan.DATASET_ID_TAG) + val rightTagIdMap = planPart1.right.getTagValue(LogicalPlan.DATASET_ID_TAG) + if (!inputSet.contains(attr) || + (planPart1.left.outputSet.contains(attr) && !leftTagIdMap.contains(attribTagId)) || + (planPart1.right.outputSet.contains(attr) && !rightTagIdMap.contains(attribTagId))) { + UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY)) + } else { + attr + } + }) + withPlan( + Join(planPart1.left, planPart1.right, + JoinType(joinType), joinExprsRectified, JoinHint.NONE)) + .queryExecution.analyzed.asInstanceOf[Join] + } + /** * Join with another `DataFrame`, using the given join expression. The following performs * a full outer join between `df1` and `df2`. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 09ccbd508e45..94c0a3bd56f1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -546,73 +546,4 @@ public void testUDF() { .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assertions.assertArrayEquals(expected, result); } - - @Test - public void testDedupBehaviourOnProjection_SPARK_47217() { - //Create table1 DF - List table1Data = Arrays.asList( - RowFactory.create(1, 2, 3), - RowFactory.create(1, 2, 3)); - StructType table1Scema = new StructType() - .add("col11", DataTypes.IntegerType) - .add("col12", DataTypes.IntegerType) - .add("col13", DataTypes.IntegerType); - - Dataset table1 = spark.createDataFrame(table1Data, table1Scema); - - //Create table2 DataFrame - List table2Data = Arrays.asList(RowFactory.create(1, 2, 3)); - StructType table2Schema = new StructType() - .add("col21", DataTypes.IntegerType) - .add("col22", DataTypes.IntegerType) - .add("col23", DataTypes.IntegerType); - - Dataset table2 = spark.createDataFrame(table2Data, table2Schema); - - //Create table 3 DataFrame - List table3Data = Arrays.asList(RowFactory.create(1, 2, 3), RowFactory.create(1, 2, 3)); - StructType table3Schema = new StructType() - .add("col31", DataTypes.IntegerType) - .add("col32", DataTypes.IntegerType) - .add("col33", DataTypes.IntegerType); - - Dataset table3 = spark.createDataFrame(table3Data, table3Schema); - - //Perform left outer join for table2 - Dataset srcDf = table1.join( - table2, - table1.col("col11").equalTo(table2.col("col21")), - "left_outer").select( - table1.col("col11"), - table1.col("col12"), - table1.col("col13"), - table2.col("col22")); - - //Perform leftouter join for exchange table2(firstjoin) - srcDf = srcDf.join( - broadcast(table3), - srcDf.col("col12").equalTo(table3.col("col32")), - 
"left_outer").select(srcDf.col("col11"), - srcDf.col("col12"), - srcDf.col("col13"), - table3.col("col33").as("col33_1")); - - //Perform left outer joinfor exchangeRateTable1 again(secondjoin) - Dataset temp = srcDf.join(broadcast(table3), - srcDf.col("col11").equalTo(table3.col("col31")), "left_outer"); - - srcDf = temp.select(srcDf.col("col11"), srcDf.col("col12"), - srcDf.col("col13"), srcDf.col("col33_1"), - table3.col("col33").as("col33_2")); - - // verify optimized plan creation ok - srcDf.queryExecution().optimizedPlan(); - LogicalPlan lp = srcDf.queryExecution().analyzed(); - // verify attribute ref resolution is correct, i.e it resolves to the right leg of join - AttributeReference refToCheck = - (AttributeReference) ((Alias)((Project)lp).projectList().last()).child(); - AttributeSet compareSet = - ((Join) ((Project) lp).child()).right().outputSet(); - Assertions.assertTrue(compareSet.contains(refToCheck)); - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala index 3db8d374c4bb..ec80c782b5b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala @@ -175,7 +175,7 @@ class DataFrameAsOfJoinSuite extends QueryTest ) } - test("SPARK_47217: Dedup of relations can impact projected columns resolution -1") { + test("SPARK_47217: Dedup of relations can impact projected columns resolution") { val (df1, df2) = prepareForAsOfJoin() val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"), df2.col("b"), df1.col("a").as("aa")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 83d289fcd207..d14cdc9c4225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, BinaryExpression, PythonUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, Join, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, explode, sum, year} import org.apache.spark.sql.internal.SQLConf @@ -97,7 +97,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(e.message.contains("ambiguous")) } - test("SPARK-28344: fail ambiguous self join - column ref in join condition") { + test("SPARK-28344: NOT AN ambiguous self join - column ref in join condition") { val df1 = spark.range(3) val df2 = df1.filter($"id" > 0) @@ -118,29 +118,41 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id"))) + val df = df1.join(df2, df1("id") > df2("id")) + val join = 
df.queryExecution.analyzed.asInstanceOf[Join] + val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] + assert(join.left.outputSet.contains(binaryCondition.left.references.head)) + assert(join.right.outputSet.contains(binaryCondition.right.references.head)) } } - test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") { + test("SPARK-28344: Not AN ambiguous self join - Dataset.colRegex as column ref") { val df1 = spark.range(3) val df2 = df1.filter($"id" > 0) withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id"))) + val df = df1.join(df2, df1.colRegex("id") > df2.colRegex("id")) + val join = df.queryExecution.analyzed.asInstanceOf[Join] + val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] + assert(join.left.outputSet.contains(binaryCondition.left.references.head)) + assert(join.right.outputSet.contains(binaryCondition.right.references.head)) } } - test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") { + test("SPARK-28344: Not An ambiguous self join - Dataset.col with nested field") { val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS()) val df2 = df1.filter($"a.b" > 0) withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c"))) + val df = df1.join(df2, df1("a.b") > df2("a.c")) + val join = df.queryExecution.analyzed.asInstanceOf[Join] + val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] + assert(join.left.outputSet.contains(binaryCondition.left.references.head)) + assert(join.right.outputSet.contains(binaryCondition.right.references.head)) } } @@ -293,14 +305,14 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(col1DsId !== col2DsId) } - test("SPARK-35454: fail ambiguous self join - toDF") { + test("SPARK-35454: Not an ambiguous self join - toDF") { val df1 = spark.range(3).toDF() val df2 = df1.filter($"id" > 0).toDF() withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - assertAmbiguousSelfJoin(df1.join(df2, df1.col("id") > df2.col("id"))) + df1.join(df2, df1.col("id") > df2.col("id")) } } @@ -353,20 +365,20 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { // Test for Project val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") - assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2"))) - assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2"))) + df1.join(df2, df1("key1") === df2("key2")) + df2.join(df1, df1("key1") === df2("key2")) // Test for SerializeFromObject val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF() val df4 = df3.filter($"_1" <=> 0) - assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2"))) - assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2"))) + df3.join(df4, df3("_1") === df4("_2")) + df4.join(df3, df3("_1") === df4("_2")) // Test For Aggregate val df5 = df1.groupBy($"key1").agg(count($"value") as "count") val df6 = df5.filter($"key1" > 0) - assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count"))) - assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count"))) + df5.join(df6, df5("key1") === df6("count")) + df6.join(df5, df5("key1") === 
df6("count")) // Test for MapInPandas val mapInPandasUDF = PythonUDF("mapInPandasUDF", null, @@ -376,8 +388,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { true) val df7 = df1.mapInPandas(mapInPandasUDF) val df8 = df7.filter($"x" > 0) - assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y"))) - assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y"))) + df7.join(df8, df7("x") === df8("y")) + df8.join(df7, df7("x") === df8("y")) // Test for FlatMapGroupsInPandas val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null, @@ -387,8 +399,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { true) val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF) val df10 = df9.filter($"x" > 0) - assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y"))) - assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y"))) + df9.join(df10, df9("x") === df10("y")) + df10.join(df9, df9("x") === df10("y")) // Test for FlatMapCoGroupsInPandas val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null, @@ -399,22 +411,22 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas( df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF) val df12 = df11.filter($"x" > 0) - assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y"))) - assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) + df11.join(df12, df11("x") === df12("y")) + df12.join(df11, df11("x") === df12("y")) // Test for AttachDistributedSequence val df13 = df1.withSequenceColumn("seq") val df14 = df13.filter($"value" === "A2") - assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) - assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) + df13.join(df14, df13("key1") === df14("key2")) + df14.join(df13, df13("key1") === df14("key2")) // Test for Generate // Ensure that the root of the plan is Generate val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList")) .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF() val df16 = df15.filter($"a" > 0) - assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col"))) - assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col"))) + df15.join(df16, df15("a") === df16("col")) + df16.join(df15, df15("a") === df16("col")) // Test for Expand // Ensure that the root of the plan is Expand @@ -426,8 +438,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { AttributeReference("y", IntegerType)()), df1.queryExecution.logical).toDF() val df18 = df17.filter($"x" > 0) - assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y"))) - assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y"))) + df17.join(df18, df17("x") === df18("y")) + df18.join(df17, df17("x") === df18("y")) // Test for Window val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b") @@ -438,8 +450,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { Seq(SortOrder(dfWithTS("a").expr, Ascending)), dfWithTS.queryExecution.logical).toDF() val df20 = df19.filter($"a" > 0) - assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b"))) - assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b"))) + df19.join(df20, df19("a") === df20("b")) + df20.join(df19, df19("a") === df20("b")) // Test for ScriptTransformation val ioSchema = @@ -464,8 +476,8 @@ class 
DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { df1.queryExecution.logical, ioSchema).toDF() val df22 = df21.filter($"x" > 0) - assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y"))) - assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y"))) + df21.join(df22, df21("x") === df22("y")) + df22.join(df21, df21("x") === df22("y")) } test("SPARK-35937: GetDateFieldOperations should skip unresolved nodes") { @@ -489,13 +501,31 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { } } - test("SPARK_47217: Dedup of relations can impact projected columns resolution") { + test("SPARK_47217: deduplication of project causes ambiguity in resolution") { val df = Seq((1, 2)).toDF("a", "b") val df2 = df.select(df("a").as("aa"), df("b").as("bb")) val df3 = df2.join(df, df2("bb") === df("b")).select(df2("aa"), df("a")) - checkAnswer( df3, Row(1, 1) :: Nil) } + + test("SPARK-47217. deduplication in nested joins focussing on projection") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((1, 2)).toDF("aa", "bb") + val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a").as("aaa"), + df2("aa"), df1("b")) + val df3 = df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")). + select(df1Joindf2("aa"), df1("a")) + df3.queryExecution.assertAnalyzed() + } + + test("SPARK-47217. deduplication in nested joins focusing on condition") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((1, 2)).toDF("aa", "bb") + val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a"), + df2("aa"), df1("b")) + val df3 = df1Joindf2.join(df1, df1Joindf2("aa") === df1("a")) + df3.queryExecution.assertAnalyzed() + } } From b8e369c239f17cba34df2c816a6f5b412c5aa63b Mon Sep 17 00:00:00 2001 From: ashahid Date: Tue, 5 Mar 2024 23:50:18 -0800 Subject: [PATCH 05/10] SPARK-47217 : Fix unused import issue --- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 94c0a3bd56f1..26a19cbed1b9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -24,7 +24,6 @@ import java.math.BigInteger; import java.math.BigDecimal; -import org.apache.spark.sql.catalyst.plans.logical.Join; import scala.collection.Seq; import scala.jdk.javaapi.CollectionConverters; @@ -32,11 +31,6 @@ import com.google.common.primitives.Ints; import org.junit.jupiter.api.*; -import org.apache.spark.sql.catalyst.expressions.Alias; -import org.apache.spark.sql.catalyst.expressions.AttributeReference; -import org.apache.spark.sql.catalyst.expressions.AttributeSet; -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; -import org.apache.spark.sql.catalyst.plans.logical.Project; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Column; From 872fece15d69b27a1bd40c7ea1e99ce2a604788c Mon Sep 17 00:00:00 2001 From: ashahid Date: Wed, 6 Mar 2024 14:36:19 -0800 Subject: [PATCH 06/10] SPARK-47217 : fixed bug and made assertions in existing tests for correct resolution of attributes --- .../analysis/ColumnResolutionHelper.scala | 19 +-- .../spark/sql/DataFrameSelfJoinSuite.scala | 148 +++++++++++------- 2 files changed, 99 insertions(+), 68 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 22361b0e73c7..bd1e561dfe06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -485,26 +485,15 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { resolveOnDatasetId = (datasetid: Long, name: String) => { def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[LogicalPlan] = { - var currentLp = lp - while(currentLp.children.size < 2) { - if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) { - return Option(currentLp) - } else { - if (currentLp.children.size == 1) { - currentLp = currentLp.children.head - } else { - // leaf node - return None - } - } + if (lp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) { + Option(lp) + } else { + None } - None } - val binaryNodeOpt = q.collectFirst { case bn: BinaryNode => bn } - val resolveOnAttribs = binaryNodeOpt match { case Some(bn) => val leftDefOpt = findUnaryNodeMatchingTagId(bn.left) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index d14cdc9c4225..c962f31eefac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -97,6 +97,27 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(e.message.contains("ambiguous")) } + private def assertCorrectResolution( + df: => DataFrame, + leftResolution: Resolution.Resolution, + rightResolution: Resolution.Resolution): Unit = { + val join = df.queryExecution.analyzed.asInstanceOf[Join] + val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] + leftResolution match { + case Resolution.LeftConditionToLeftLeg => + assert(join.left.outputSet.contains(binaryCondition.left.references.head)) + case Resolution.LeftConditionToRightLeg => + assert(join.right.outputSet.contains(binaryCondition.left.references.head)) + } + + rightResolution match { + case Resolution.RightConditionToLeftLeg => + assert(join.left.outputSet.contains(binaryCondition.right.references.head)) + case Resolution.RightConditionToRightLeg => + assert(join.right.outputSet.contains(binaryCondition.right.references.head)) + } + } + test("SPARK-28344: NOT AN ambiguous self join - column ref in join condition") { val df1 = spark.range(3) val df2 = df1.filter($"id" > 0) @@ -118,11 +139,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - val df = df1.join(df2, df1("id") > df2("id")) - val join = df.queryExecution.analyzed.asInstanceOf[Join] - val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] - assert(join.left.outputSet.contains(binaryCondition.left.references.head)) - assert(join.right.outputSet.contains(binaryCondition.right.references.head)) + assertCorrectResolution(df1.join(df2, df1("id") > df2("id")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) } } @@ -133,11 +151,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { 
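+    // df2 is derived from df1, so df1.colRegex("id") and df2.colRegex("id")
+    // trace back to the same relation; the dataset id recorded in each
+    // column's metadata is what binds each side of the condition to its leg.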
withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - val df = df1.join(df2, df1.colRegex("id") > df2.colRegex("id")) - val join = df.queryExecution.analyzed.asInstanceOf[Join] - val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] - assert(join.left.outputSet.contains(binaryCondition.left.references.head)) - assert(join.right.outputSet.contains(binaryCondition.right.references.head)) + assertCorrectResolution(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) } } @@ -148,11 +163,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - val df = df1.join(df2, df1("a.b") > df2("a.c")) - val join = df.queryExecution.analyzed.asInstanceOf[Join] - val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression] - assert(join.left.outputSet.contains(binaryCondition.left.references.head)) - assert(join.right.outputSet.contains(binaryCondition.right.references.head)) + assertCorrectResolution( df1.join(df2, df1("a.b") > df2("a.c")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) } } @@ -312,7 +324,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - df1.join(df2, df1.col("id") > df2.col("id")) + assertCorrectResolution(df1.join(df2, df1.col("id") > df2.col("id")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) } } @@ -363,22 +376,30 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " + "to avoid ambiguous self join") { // Test for Project + val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = df1.filter($"value" === "A2") - df1.join(df2, df1("key1") === df2("key2")) - df2.join(df1, df1("key1") === df2("key2")) + /* assertCorrectResolution(df1.join(df2, df1("key1") === df2("key2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df2.join(df1, df1("key1") === df2("key2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + // Test for SerializeFromObject val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF() val df4 = df3.filter($"_1" <=> 0) - df3.join(df4, df3("_1") === df4("_2")) - df4.join(df3, df3("_1") === df4("_2")) + assertCorrectResolution(df3.join(df4, df3("_1") === df4("_2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df4.join(df3, df3("_1") === df4("_2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test For Aggregate val df5 = df1.groupBy($"key1").agg(count($"value") as "count") val df6 = df5.filter($"key1" > 0) - df5.join(df6, df5("key1") === df6("count")) - df6.join(df5, df5("key1") === df6("count")) + assertCorrectResolution(df5.join(df6, df5("key1") === df6("count")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df6.join(df5, df5("key1") === df6("count")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for MapInPandas val mapInPandasUDF = PythonUDF("mapInPandasUDF", null, @@ -388,8 +409,10 @@ class 
DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { true) val df7 = df1.mapInPandas(mapInPandasUDF) val df8 = df7.filter($"x" > 0) - df7.join(df8, df7("x") === df8("y")) - df8.join(df7, df7("x") === df8("y")) + assertCorrectResolution(df7.join(df8, df7("x") === df8("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df8.join(df7, df7("x") === df8("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for FlatMapGroupsInPandas val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null, @@ -399,9 +422,11 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { true) val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF) val df10 = df9.filter($"x" > 0) - df9.join(df10, df9("x") === df10("y")) - df10.join(df9, df9("x") === df10("y")) - + assertCorrectResolution(df9.join(df10, df9("x") === df10("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df10.join(df9, df9("x") === df10("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + */ // Test for FlatMapCoGroupsInPandas val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null, StructType(Seq(StructField("x", LongType), StructField("y", LongType))), @@ -411,22 +436,27 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas( df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF) val df12 = df11.filter($"x" > 0) - df11.join(df12, df11("x") === df12("y")) - df12.join(df11, df11("x") === df12("y")) + /* assertCorrectResolution(df11.join(df12, df11("x") === df12("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) */ + assertCorrectResolution(df12.join(df11, df11("x") === df12("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for AttachDistributedSequence val df13 = df1.withSequenceColumn("seq") val df14 = df13.filter($"value" === "A2") - df13.join(df14, df13("key1") === df14("key2")) - df14.join(df13, df13("key1") === df14("key2")) - + assertCorrectResolution(df13.join(df14, df13("key1") === df14("key2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df14.join(df13, df13("key1") === df14("key2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for Generate // Ensure that the root of the plan is Generate val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList")) .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF() val df16 = df15.filter($"a" > 0) - df15.join(df16, df15("a") === df16("col")) - df16.join(df15, df15("a") === df16("col")) + assertCorrectResolution(df15.join(df16, df15("a") === df16("col")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df16.join(df15, df15("a") === df16("col")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for Expand // Ensure that the root of the plan is Expand @@ -438,9 +468,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { AttributeReference("y", IntegerType)()), df1.queryExecution.logical).toDF() val df18 = df17.filter($"x" > 0) - df17.join(df18, df17("x") === df18("y")) - df18.join(df17, df17("x") === df18("y")) - + assertCorrectResolution(df17.join(df18, df17("x") 
=== df18("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df18.join(df17, df17("x") === df18("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for Window val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b") // Ensure that the root of the plan is Window @@ -450,9 +481,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { Seq(SortOrder(dfWithTS("a").expr, Ascending)), dfWithTS.queryExecution.logical).toDF() val df20 = df19.filter($"a" > 0) - df19.join(df20, df19("a") === df20("b")) - df20.join(df19, df19("a") === df20("b")) - + assertCorrectResolution(df19.join(df20, df19("a") === df20("b")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df20.join(df19, df19("a") === df20("b")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) // Test for ScriptTransformation val ioSchema = ScriptInputOutputSchema( @@ -476,8 +508,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { df1.queryExecution.logical, ioSchema).toDF() val df22 = df21.filter($"x" > 0) - df21.join(df22, df21("x") === df22("y")) - df22.join(df21, df21("x") === df22("y")) + assertCorrectResolution(df21.join(df22, df21("x") === df22("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df22.join(df21, df21("x") === df22("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) } test("SPARK-35937: GetDateFieldOperations should skip unresolved nodes") { @@ -515,17 +549,25 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { val df2 = Seq((1, 2)).toDF("aa", "bb") val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a").as("aaa"), df2("aa"), df1("b")) - val df3 = df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")). - select(df1Joindf2("aa"), df1("a")) - df3.queryExecution.assertAnalyzed() - } - test("SPARK-47217. deduplication in nested joins focusing on condition") { - val df1 = Seq((1, 2)).toDF("a", "b") - val df2 = Seq((1, 2)).toDF("aa", "bb") - val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a"), - df2("aa"), df1("b")) - val df3 = df1Joindf2.join(df1, df1Joindf2("aa") === df1("a")) - df3.queryExecution.assertAnalyzed() + assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + + assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), df1("a")). + queryExecution.analyzed + + df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), df1("a")). 
+ queryExecution.analyzed } } + +object Resolution extends Enumeration { + type Resolution = Value + + val LeftConditionToLeftLeg, LeftConditionToRightLeg, RightConditionToRightLeg, + RightConditionToLeftLeg = Value +} + From 8ed6aa48502dd4ca820671b58a7136cb13fbdcb7 Mon Sep 17 00:00:00 2001 From: ashahid Date: Wed, 6 Mar 2024 17:30:13 -0800 Subject: [PATCH 07/10] SPARK-47217 : added more assetions --- .../spark/sql/DataFrameSelfJoinSuite.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index e0388f12cee6..b6bf214e8407 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, BinaryExpression, PythonUDF, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, Join, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, Join, Project, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, explode, sum, year} import org.apache.spark.sql.internal.SQLConf @@ -566,11 +566,17 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")), Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) - df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), df1("a")). - queryExecution.analyzed - - df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), df1("a")). 
- queryExecution.analyzed + val proj1 = df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet)) + assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet)) + + val proj2 = df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join2 = proj2.child.asInstanceOf[Join] + assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet)) + assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet)) } } From 6b3b1d4ed0549a627df837631118d6058ec5f91a Mon Sep 17 00:00:00 2001 From: ashahid Date: Wed, 6 Mar 2024 21:24:26 -0800 Subject: [PATCH 08/10] SPARK-47217 : fixed a bug and uncommented tests which were inadvertently commented --- .../analysis/ColumnResolutionHelper.scala | 39 +++++++++++++++---- .../spark/sql/DataFrameSelfJoinSuite.scala | 10 ++--- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index bd1e561dfe06..4aac2c6c7067 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -484,23 +484,48 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { }, resolveOnDatasetId = (datasetid: Long, name: String) => { - def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[LogicalPlan] = { - if (lp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) { - Option(lp) - } else { - None + def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = { + var currentLp = lp + var depth = 0 + while(true) { + if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) { + return Option(currentLp, depth) + } else { + if (currentLp.children.size == 1) { + currentLp = currentLp.children.head + } else { + // leaf node or node is a binary node + return None + } + } + depth += 1 } + None } + val binaryNodeOpt = q.collectFirst { case bn: BinaryNode => bn } + val resolveOnAttribs = binaryNodeOpt match { case Some(bn) => val leftDefOpt = findUnaryNodeMatchingTagId(bn.left) val rightDefOpt = findUnaryNodeMatchingTagId(bn.right) (leftDefOpt, rightDefOpt) match { - case (None, Some(lp)) => lp.output - case (Some(lp), None) => lp.output + + case (None, Some((lp, _))) => lp.output + + case (Some((lp, _)), None) => lp.output + + case (Some((lp1, depth1)), Some((lp2, depth2))) => + if (depth1 == depth2) { + q.children.head.output + } else if (depth1 < depth2) { + lp1.output + } else { + lp2.output + } + case _ => q.children.head.output } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index b6bf214e8407..9aa7707f8361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -379,7 +379,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") val df2 = 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index b6bf214e8407..9aa7707f8361 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -379,7 +379,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
     val df2 = df1.filter($"value" === "A2")
 
-    /* assertCorrectResolution(df1.join(df2, df1("key1") === df2("key2")),
+    assertCorrectResolution(df1.join(df2, df1("key1") === df2("key2")),
       Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     assertCorrectResolution(df2.join(df1, df1("key1") === df2("key2")),
       Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
@@ -426,7 +426,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     assertCorrectResolution(df10.join(df9, df9("x") === df10("y")),
       Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
-    */
+
     // Test for FlatMapCoGroupsInPandas
     val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
       StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
@@ -436,8 +436,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
       df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
     val df12 = df11.filter($"x" > 0)
-    /* assertCorrectResolution(df11.join(df12, df11("x") === df12("y")),
-      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) */
+    assertCorrectResolution(df11.join(df12, df11("x") === df12("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     assertCorrectResolution(df12.join(df11, df11("x") === df12("y")),
       Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
@@ -584,6 +584,6 @@ object Resolution extends Enumeration {
   type Resolution = Value
 
   val LeftConditionToLeftLeg, LeftConditionToRightLeg, RightConditionToRightLeg,
-  RightConditionToLeftLeg = Value
+    RightConditionToLeftLeg = Value
 }
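The next patch changes Dataset.join to rectify the join condition unconditionally rather than retrying only after an "ambiguous" AnalysisException. A control-flow sketch with String stand-ins (analyze/rectify are assumed placeholders, not the Spark API):

    object JoinFlowSketch {
      def analyze(plan: String): String = s"analyzed($plan)"   // stand-in for queryExecution.analyzed
      def rectify(plan: String): String = s"rectified($plan)"  // stand-in for tryAmbiguityResolution

      // Before: analyze optimistically, rectify only when an "ambiguous"
      // analysis error surfaces.
      def joinBefore(plan: String): String =
        try analyze(plan) catch {
          case e: RuntimeException if e.getMessage.contains("ambiguous") => analyze(rectify(plan))
        }

      // After: rectify unconditionally and analyze exactly once, so reused
      // DataFrames cannot resolve differently between the two attempts.
      def joinAfter(plan: String): String = analyze(rectify(plan))

      def main(args: Array[String]): Unit = println(joinAfter("Join(df1Joindf2, df1)"))
    }

The design point is that the catch-based path only fired when analysis happened to throw, while plans that resolved silently to the wrong leg slipped through; rectifying up front makes both cases take the same code path.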
From f9653ec7d439776a8b3c384fe1bebf045e7ad9ca Mon Sep 17 00:00:00 2001
From: ashahid
Date: Thu, 7 Mar 2024 15:20:25 -0800
Subject: [PATCH 09/10] SPARK-47217 : added more tests and fixed an
 inconsistency

join() now always routes through tryAmbiguityResolution instead of
retrying only on an ambiguity error, and select() re-tags columns whose
dataset id metadata disagrees with the id recorded on the plan's own
output attribute.
---
 .../scala/org/apache/spark/sql/Dataset.scala  | 54 +++++++++++--------
 .../spark/sql/DataFrameSelfJoinSuite.scala    | 49 +++++++++++++++--
 2 files changed, 77 insertions(+), 26 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d79c2c499bb2..2bd48e45dd84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.sql.Dataset.DATASET_ID_KEY
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -1151,17 +1152,9 @@ class Dataset[T] private[sql](
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
-    val plan = try {
-      withPlan(
-        Join(logicalPlan, right.logicalPlan,
-          JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
-        .queryExecution.analyzed.asInstanceOf[Join]
-    } catch {
-      case ae: AnalysisException if ae.message.contains("ambiguous") =>
-        // attempt to resolve ambiguity
-        tryAmbiguityResolution(right, joinExprs, joinType)
-    }
-
+    val plan = withPlan(
+      tryAmbiguityResolution(right, joinExprs, joinType)
+    ).queryExecution.analyzed.asInstanceOf[Join]
 
     // If auto self join alias is disabled, return the plan.
     if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
@@ -1199,15 +1192,13 @@ class Dataset[T] private[sql](
       if (!inputSet.contains(attr) ||
         (planPart1.left.outputSet.contains(attr) && !leftTagIdMap.contains(attribTagId)) ||
         (planPart1.right.outputSet.contains(attr) && !rightTagIdMap.contains(attribTagId))) {
-        UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
+        UnresolvedAttributeWithTag(attr, attribTagId)
       } else {
         attr
       }
     })
-    withPlan(
-      Join(planPart1.left, planPart1.right,
-        JoinType(joinType), joinExprsRectified, JoinHint.NONE))
-      .queryExecution.analyzed.asInstanceOf[Join]
+
+    Join(planPart1.left, planPart1.right, JoinType(joinType), joinExprsRectified, JoinHint.NONE)
   }
 
   /**
@@ -1624,11 +1615,25 @@ class Dataset[T] private[sql](
     val inputSet = logicalPlan.outputSet
 
     val rectifiedNamedExprs = namedExprs.map(ne => ne match {
-      case al: Alias if !al.references.subsetOf(inputSet) &&
+      case al: Alias if (!al.references.subsetOf(inputSet) || al.references.exists(attr =>
+        attr.metadata.contains(DATASET_ID_KEY) && attr.metadata.getLong(DATASET_ID_KEY) !=
+          inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+            if (x.metadata.contains(DATASET_ID_KEY)) {
+              x.metadata.getLong(DATASET_ID_KEY)
+            } else {
+              -1
+            }).get)) &&
         al.nonInheritableMetadataKeys.contains(Dataset.DATASET_ID_KEY) =>
         val unresolvedExpr = al.child.transformUp {
-          case attr: AttributeReference if !inputSet.contains(attr) &&
-            attr.metadata.contains(Dataset.DATASET_ID_KEY) =>
+          case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
+            (!inputSet.contains(attr) || attr.metadata.getLong(DATASET_ID_KEY) !=
+              inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+                if (x.metadata.contains(DATASET_ID_KEY)) {
+                  x.metadata.getLong(DATASET_ID_KEY)
+                } else {
+                  -1
+                }).get)
+          =>
             UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
         }
         val newAl = al.copy(child = unresolvedExpr, name = al.name)(exprId = al.exprId,

         newAl.copyTagsFrom(al)
         newAl
 
-      case attr: Attribute if !inputSet.contains(attr) &&
-        attr.metadata.contains(Dataset.DATASET_ID_KEY) =>
+      case attr: Attribute if attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
+        (!inputSet.contains(attr) || attr.metadata.getLong(DATASET_ID_KEY) !=
+          inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+            if (x.metadata.contains(DATASET_ID_KEY)) {
+              x.metadata.getLong(DATASET_ID_KEY)
+            } else {
+              -1
+            }).get)
+        =>
         UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
 
       case _ => ne
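The rectification above keys entirely off the dataset id stored in each column's attribute metadata. A small probe of that mechanism (the key string "__dataset_id" is an assumption here; the real constant Dataset.DATASET_ID_KEY is private[sql]) showing that the same resolved attribute carries a different id depending on which Dataset handed it out:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.expressions.AttributeReference

    object DatasetIdProbe {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._

        val df1 = Seq((1, 2)).toDF("a", "b")
        val df2 = df1.filter($"a" === 1) // new Dataset over the same plan and attributes

        def datasetIdOf(c: org.apache.spark.sql.Column): Long =
          c.expr.asInstanceOf[AttributeReference].metadata.getLong("__dataset_id")

        // Same column lineage, different originating Dataset => different ids.
        // This disagreement is exactly what the select() rectification detects.
        assert(datasetIdOf(df1("a")) != datasetIdOf(df2("a")))
        spark.stop()
      }
    }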
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 9aa7707f8361..4fccf9d2415c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -189,7 +189,9 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2).select(df2("id")))
+      val proj1 = df1.join(df2).select(df2("id")).queryExecution.analyzed.asInstanceOf[Project]
+      val join1 = proj1.child.asInstanceOf[Join]
+      assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet))
     }
   }
@@ -229,7 +231,11 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
       assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id")))
-      assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id")))
+
+      val proj1 = df1.join(df4).join(df2).select(df2("id")).queryExecution.analyzed.
+        asInstanceOf[Project]
+      val join1 = proj1.child.asInstanceOf[Join]
+      assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet))
     }
   }
@@ -261,8 +267,17 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       TestData(2, "personnel"),
       TestData(3, "develop")).toDS()
     val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
-    assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
-      "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
+
+    assertCorrectResolution(emp1.join(emp3, emp1.col("key") === emp3.col("key")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+    val proj1 = emp1.join(emp3, emp1.col("key") === emp3.col("key"),
+      "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")).
+      queryExecution.analyzed.asInstanceOf[Project]
+    val join1 = proj1.child.asInstanceOf[Join]
+    assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(1).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(2).references.subsetOf(join1.right.outputSet))
   }
 
   test("df.show() should also not change dataset_id of LogicalPlan") {
@@ -554,7 +569,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       Row(1, 1) :: Nil)
   }
 
-  test("SPARK-47217. deduplication in nested joins") {
+  test("SPARK-47217. deduplication in nested joins with join attribute aliased") {
     val df1 = Seq((1, 2)).toDF("a", "b")
     val df2 = Seq((1, 2)).toDF("aa", "bb")
     val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a").as("aaa"),
@@ -578,6 +593,30 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet))
     assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet))
   }
+
+  test("SPARK-47217. deduplication in nested joins without join attribute aliased") {
+    val df1 = Seq((1, 2)).toDF("a", "b")
+    val df2 = Seq((1, 2)).toDF("aa", "bb")
+    val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a"), df2("aa"), df1("b"))
+
+    assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("a") === df1("a")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+    assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("a") === df1("a")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
+
+    val proj1 = df1Joindf2.join(df1, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join1 = proj1.child.asInstanceOf[Join]
+    assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet))
+
+    val proj2 = df1.join(df1Joindf2, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join2 = proj2.child.asInstanceOf[Join]
+    assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet))
+    assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet))
+  }
 }
 
 object Resolution extends Enumeration {
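The two new tests above exercise the query shape this whole series targets. A minimal standalone reproduction sketch of the underlying ambiguity (assumed data; on builds without this fix the final join can fail analysis with an ambiguous-column error):

    import org.apache.spark.sql.SparkSession

    object NestedSelfJoinRepro {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._

        val df1 = Seq((1, 2)).toDF("a", "b")
        val df2 = Seq((1, 2)).toDF("aa", "bb")
        val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a"), df2("aa"), df1("b"))

        // df1 is reused: it feeds df1Joindf2 and is also the right join leg,
        // so df1("a") is present on both sides after plan de-duplication. The
        // dataset id carried in each column's metadata is what lets the
        // analyzer route df1Joindf2("a") left and df1("a") right.
        df1Joindf2.join(df1, df1Joindf2("a") === df1("a")).show()
        spark.stop()
      }
    }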
From 7150c9887acd217bdae3c9e58d58d37e070ab0e1 Mon Sep 17 00:00:00 2001
From: ashahid
Date: Fri, 8 Mar 2024 11:49:22 -0800
Subject: [PATCH 10/10] SPARK-47217 : fixed a test failure

When the matching attribute in the current plan's output carries no
dataset id metadata, compare against this Dataset's own id instead of
the sentinel -1, and drop the over-restrictive check on the Alias's
non-inheritable metadata keys.
---
 .../main/scala/org/apache/spark/sql/Dataset.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2bd48e45dd84..0f15fcf51b8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1615,15 +1615,15 @@ class Dataset[T] private[sql](
     val inputSet = logicalPlan.outputSet
 
     val rectifiedNamedExprs = namedExprs.map(ne => ne match {
-      case al: Alias if (!al.references.subsetOf(inputSet) || al.references.exists(attr =>
+      case al: Alias if !al.references.subsetOf(inputSet) || al.references.exists(attr =>
         attr.metadata.contains(DATASET_ID_KEY) && attr.metadata.getLong(DATASET_ID_KEY) !=
           inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
             if (x.metadata.contains(DATASET_ID_KEY)) {
               x.metadata.getLong(DATASET_ID_KEY)
             } else {
-              -1
-            }).get)) &&
-        al.nonInheritableMetadataKeys.contains(Dataset.DATASET_ID_KEY) =>
+              Dataset.this.id
+            }).get)
+      =>
         val unresolvedExpr = al.child.transformUp {
           case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
             (!inputSet.contains(attr) || attr.metadata.getLong(DATASET_ID_KEY) !=
               inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
                 if (x.metadata.contains(DATASET_ID_KEY)) {
                   x.metadata.getLong(DATASET_ID_KEY)
                 } else {
-                  -1
+                  Dataset.this.id
                 }).get)
           =>
             UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
@@ -1648,7 +1648,7 @@ class Dataset[T] private[sql](
             if (x.metadata.contains(DATASET_ID_KEY)) {
               x.metadata.getLong(DATASET_ID_KEY)
             } else {
-              -1
+              Dataset.this.id
             }).get)
         =>
         UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))