diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 899ee67352df..cadf3c5fa779 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -390,7 +390,14 @@ class Analyzer(
             case a: Attribute => attributeRewrites.get(a).getOrElse(a)
           }
         }
-        j.copy(right = newRight)
+        // For condition attributes tagged "RIGHT_TREE", apply attributeRewrites and drop the tag.
+        // headOption (not head): qualifiers may be empty, in which case head would throw.
+        val newCondition = j.condition.map(_.transform {
+          case a: AttributeReference
+              if a.resolved && a.qualifiers.headOption == Some("RIGHT_TREE") =>
+            attributeRewrites.get(a).getOrElse(a).withQualifiers(Nil)
+        })
+        j.copy(right = newRight, condition = newCondition)
       }
 
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9ab5c299d0f5..18f576fc0d50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -194,7 +194,8 @@ case class AttributeReference(
   def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId
 
   override def equals(other: Any): Boolean = other match {
-    case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
+    case ar: AttributeReference => name == ar.name && exprId == ar.exprId &&
+      dataType == ar.dataType && qualifiers == ar.qualifiers
     case _ => false
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f2d4db555027..bcadcf1dec7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -567,6 +567,16 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+    // Re-tag attributes qualified with a DataFrame's hashCode as "LEFT_TREE"/"RIGHT_TREE".
+    // headOption (not head): attributes with empty qualifiers must pass through untouched.
+    val newJoinExprs = joinExprs.expr.transform {
+      case arLeft: AttributeReference
+          if arLeft.qualifiers.headOption == Some(this.hashCode.toString) =>
+        arLeft.withQualifiers("LEFT_TREE" :: Nil)
+      case arRight: AttributeReference
+          if arRight.qualifiers.headOption == Some(right.hashCode.toString) =>
+        arRight.withQualifiers("RIGHT_TREE" :: Nil)
+    }
     // Note that in this function, we introduce a hack in the case of self-join to automatically
     // resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
     // Consider this case: df.join(df, df("key") === df("key"))
@@ -578,7 +588,7 @@ class DataFrame private[sql](
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
     val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(newJoinExprs)))
       .queryExecution.analyzed.asInstanceOf[Join]
 
     // If auto self join alias is disabled, return the plan.
@@ -700,7 +710,11 @@ class DataFrame private[sql](
     case "*" =>
       Column(ResolvedStar(queryExecution.analyzed.output))
     case _ =>
-      val expr = resolve(colName)
+      val expr = resolve(colName) match {
+        case ar: AttributeReference => // tag with this DataFrame's hashCode as its qualifier
+          ar.withQualifiers(this.hashCode.toString :: Nil)
+        case o => o
+      }
       Column(expr)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 56ad71ea4f48..50e37700c57e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -106,6 +106,40 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
   }
 
+  test("[SPARK-10838] self join - conflicting attributes in condition - incorrect result 1") {
+    val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2")
+    val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3")
+    // df3 keeps only df1("keyCol1") from the df1/df2 join, then is joined back to df1 below.
+    val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"))
+
+    checkAnswer(
+      df3.join(df1, df1("keyCol2") === df3("keyCol1")),
+      Row(1, 2, 1) :: Nil)
+  }
+
+  test("[SPARK-10838] self join - conflicting attributes in condition - incorrect result 2") {
+    val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2")
+    val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3")
+    // df3 keeps df1("keyCol1") and keyCol3; the condition below states the equality both ways.
+    val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"), $"keyCol3")
+
+    checkAnswer(
+      df3.join(df1, df3("keyCol3") === df1("keyCol1") && df1("keyCol1") === df3("keyCol3")),
+      Row(2, 1, 1, 3) :: Nil)
+  }
+
+  test("[SPARK-10838] self join - conflicting attributes in condition - exception") {
+    val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2")
+    val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3")
+    // df4 is df2 under the alias "df4"; df3 is derived from a join that already involves df2.
+    val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"), $"keyCol3")
+    val df4 = df2.as("df4")
+    // NOTE(review): the two conjuncts below are textually identical - confirm this is intended.
+    checkAnswer(
+      df3.join(df4, df3("keyCol3") === df4("keyCol1") && df3("keyCol3") === df4("keyCol1")),
+      Row(2, 1, 1, 4) :: Nil)
+  }
+
   test("broadcast join hint") {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")