diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 95ee69d2a47d..a0729adb8960 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -940,11 +940,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
     }
     assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
 
-    val e4 = intercept[AnalysisException] {
-      // df1("i") is ambiguous as df1 appears in both join sides (df1_filter contains df1).
-      df1.join(df1_filter, df1("i") === 1).collect()
-    }
-    assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+    // TODO(SPARK-47749): Dataframe.collect should accept duplicated column names
+    assert(
+      // df1.join(df1_filter, df1("i") === 1) fails in classic spark due to:
+      // org.apache.spark.sql.AnalysisException: Column i#24 are ambiguous
+      df1.join(df1_filter, df1("i") === 1).columns ===
+        Array("i", "j", "i", "j"))
 
     checkSameResult(
       Seq(Row("a")),
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3b8e8165b4bf..16e9a577451f 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1155,6 +1155,15 @@ def test_crossjoin(self):
             set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
         )
 
+    def test_self_join(self):
+        # SPARK-47713: this query fails in classic spark
+        df1 = self.connect.createDataFrame([(1, "a")], schema=["i", "j"])
+        df1_filter = df1.filter(df1.i > 0)
+        df2 = df1.join(df1_filter, df1.i == 1)
+        self.assertEqual(df2.count(), 1)
+        self.assertEqual(df2.columns, ["i", "j", "i", "j"])
+        self.assertEqual(list(df2.first()), [1, "a", 1, "a"])
+
     def test_with_metadata(self):
         cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
         self.assertEqual(cdf.schema["age"].metadata, {})
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index dc7d39155345..1eccb40e709c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -123,6 +123,13 @@ def test_self_join(self):
         df = df2.join(df1, df2["b"] == df1["a"])
         self.assertTrue(df.count() == 100)
 
+    def test_self_join_II(self):
+        df = self.spark.createDataFrame([(1, 2), (3, 4)], schema=["a", "b"])
+        df2 = df.select(df.a.alias("aa"), df.b)
+        df3 = df2.join(df, df2.b == df.b)
+        self.assertEqual(df3.columns, ["aa", "b", "a", "b"])
+        self.assertTrue(df3.count() == 2)
+
     def test_duplicated_column_names(self):
         df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select("*").first()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 8ea50e2ceb65..6e27192ead32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -527,7 +527,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         logDebug(s"Extract plan_id $planId from $u")
         val isMetadataAccess =
           u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
-        val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q)
+
+        val (resolved, matched) = resolveDataFrameColumnByPlanId(
+          u, planId, isMetadataAccess, q, 0)
         if (!matched) {
           // Can not find the target plan node with plan id, e.g.
           //   df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
@@ -535,29 +537,35 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
           //   df1.select(df2.a) <- illegal reference df2.a
           throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
         }
-        resolved
+        resolved.map(_._1)
     }
 
   private def resolveDataFrameColumnByPlanId(
       u: UnresolvedAttribute,
       id: Long,
      isMetadataAccess: Boolean,
-      q: Seq[LogicalPlan]): (Option[NamedExpression], Boolean) = {
-    q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, _))
-      .foldLeft((Option.empty[NamedExpression], false)) {
-        case ((r1, m1), (r2, m2)) =>
-          if (r1.nonEmpty && r2.nonEmpty) {
-            throw QueryCompilationErrors.ambiguousColumnReferences(u)
-          }
-          (if (r1.nonEmpty) r1 else r2, m1 | m2)
+      q: Seq[LogicalPlan],
+      currentDepth: Int): (Option[(NamedExpression, Int)], Boolean) = {
+    val resolved = q.map(resolveDataFrameColumnRecursively(
+      u, id, isMetadataAccess, _, currentDepth))
+    val merged = resolved
+      .flatMap(_._1)
+      .sortBy(_._2) // sort by depth
+      .foldLeft(Option.empty[(NamedExpression, Int)]) {
+        case (None, (r2, d2)) => Some((r2, d2))
+        case (Some((r1, 0)), (r2, d2)) if d2 != 0 => Some((r1, 0))
+        case _ => throw QueryCompilationErrors.ambiguousColumnReferences(u)
       }
+    val matched = resolved.exists(_._2)
+    (merged, matched)
   }
 
   private def resolveDataFrameColumnRecursively(
       u: UnresolvedAttribute,
       id: Long,
       isMetadataAccess: Boolean,
-      p: LogicalPlan): (Option[NamedExpression], Boolean) = {
+      p: LogicalPlan,
+      currentDepth: Int): (Option[(NamedExpression, Int)], Boolean) = {
     val (resolved, matched) = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
       val resolved = try {
         if (!isMetadataAccess) {
@@ -572,9 +580,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
           logDebug(s"Fail to resolve $u with $p due to $e")
           None
       }
-      (resolved, true)
+      (resolved.map(r => (r, currentDepth)), true)
     } else {
-      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children)
+      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children, currentDepth + 1)
     }
 
     // In self join case like:
@@ -604,9 +612,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     // will try to resolve it without plan id later.
     val filtered = resolved.filter { r =>
       if (isMetadataAccess) {
-        r.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
+        r._1.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
       } else {
-        r.references.subsetOf(p.outputSet)
+        r._1.references.subsetOf(p.outputSet)
       }
     }
     (filtered, matched)
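
Note on the ColumnResolutionHelper change: the rewritten foldLeft now tracks the depth at which each plan-id match was found, so a depth-0 match (the plan node that carries the plan id itself) shadows matches found deeper in its subtree, and only matches on an equal footing remain ambiguous. The following standalone Scala sketch is not part of the patch; Resolved and merge are hypothetical stand-ins for NamedExpression and the private helper, shown only to illustrate the disambiguation rule the new foldLeft encodes:

object DepthDisambiguationSketch {
  // Hypothetical stand-in for a resolved NamedExpression.
  final case class Resolved(name: String)

  // candidates: (resolution, depth) pairs collected from the children of a plan.
  // Mirrors the patch: sort by depth, let a unique depth-0 match win over
  // deeper ones, and treat anything else as an ambiguous reference.
  def merge(candidates: Seq[(Resolved, Int)]): Option[(Resolved, Int)] =
    candidates
      .sortBy(_._2) // sort by depth, shallowest match first
      .foldLeft(Option.empty[(Resolved, Int)]) {
        case (None, (r, d)) => Some((r, d))
        // A unique depth-0 match shadows matches found deeper in the subtree.
        case (Some((r1, 0)), (_, d2)) if d2 != 0 => Some((r1, 0))
        case _ => throw new IllegalStateException("AMBIGUOUS_COLUMN_REFERENCE")
      }

  def main(args: Array[String]): Unit = {
    // df1.join(df1_filter, df1("i") === 1): df1("i") matches df1 itself at
    // depth 0 and the df1 inside df1_filter at depth 1; the depth-0 match now
    // wins, where the old code threw AMBIGUOUS_COLUMN_REFERENCE.
    println(merge(Seq((Resolved("i"), 0), (Resolved("i"), 1)))) // Some((Resolved(i),0))

    // Two matches at the same depth are still ambiguous:
    // merge(Seq((Resolved("i"), 1), (Resolved("i"), 1))) would throw.
  }
}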