Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add methods to facilitate equi-join on multiple joining keys.
  • Loading branch information
viirya committed Jun 3, 2015
commit cc90015e12a6229ce72196e4747ca3d0655abeb4
32 changes: 23 additions & 9 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,30 +509,42 @@ def join(self, other, joinExprs=None, joinType=None):
The following performs a full outer join between ``df1`` and ``df2``.

:param other: Right side of the join
:param joinExprs: a string for join column name, or a join expression (Column).
If joinExprs is a string indicating the name of the join column,
the column must exist on both sides, and this performs an inner equi-join.
:param joinExprs: a string for join column name, a list of column names,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's change the argument name in python as explained in the jira ticket

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok.

, a join expression (Column) or a list of Columns.
If joinExprs is a string or a list of string indicating the name of the join column(s),
the column(s) must exist on both sides, and this performs an inner equi-join.
:param joinType: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.

>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]

>>> cond = [df.name == df3.name, df.age == df3.age]
>>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
[Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]

>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name=u'Bob', height=85)]

>>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
[Row(name=u'Bob', age=5)]
"""

if joinExprs is None:
if joinExprs is not None and not isinstance(joinExprs, list):
joinExprs = [joinExprs]

if joinExprs is None or len(joinExprs) == 0:
jdf = self._jdf.join(other._jdf)
elif isinstance(joinExprs, basestring):
jdf = self._jdf.join(other._jdf, joinExprs)

if isinstance(joinExprs[0], basestring):
jdf = self._jdf.join(other._jdf, self._jseq(joinExprs))
else:
assert isinstance(joinExprs, Column), "joinExprs should be Column"
assert isinstance(joinExprs[0], Column), "joinExprs should be Column or list of Column"
if joinType is None:
jdf = self._jdf.join(other._jdf, joinExprs._jc)
jdf = self._jdf.join(other._jdf, self._jcols(joinExprs), "inner")
else:
assert isinstance(joinType, basestring), "joinType should be basestring"
jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
jdf = self._jdf.join(other._jdf, self._jcols(joinExprs), joinType)
return DataFrame(jdf, self.sql_ctx)

@ignore_unicode_prefix
Expand Down Expand Up @@ -1291,6 +1303,8 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
Row(name='Bob', age=5)]).toDF()
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
Expand Down
43 changes: 35 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -395,22 +395,35 @@ class DataFrame private[sql](
* @since 1.4.0
*/
def join(right: DataFrame, usingColumn: String): DataFrame = {
join(right, Seq(usingColumn))
}

def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add javadoc

// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sqlContext.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]

// Project only one of the join column.
val joinedCol = joined.right.resolve(usingColumn)
// Project only one of the join columns.
val joinedCols = usingColumns.map(col => joined.right.resolve(col))
val condition = usingColumns.map { col =>
catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col))
}.foldLeft[Option[catalyst.expressions.BinaryExpression]](None) { (opt, eqTo) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be simplifed into a reduceOption right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. yes. Did not notice that.

opt match {
case Some(cond) =>
Some(catalyst.expressions.And(cond, eqTo))
case None =>
Some(eqTo)
}
}

Project(
joined.output.filterNot(_ == joinedCol),
joined.output.filterNot(joinedCols.contains(_)),
Join(
joined.left,
joined.right,
joinType = Inner,
Some(catalyst.expressions.EqualTo(
joined.left.resolve(usingColumn),
joined.right.resolve(usingColumn))))
condition)
)
}

Expand All @@ -425,7 +438,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner")
def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, Seq(joinExprs), "inner")

/**
* Join with another [[DataFrame]], using the given join expression. The following performs
Expand All @@ -448,6 +461,10 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
join(right, Seq(joinExprs), joinType)
}

def join(right: DataFrame, joinExprs: Seq[Column], joinType: String): DataFrame = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should remove this one for scala. basically it's not that big of a deal for scala users to write join1 and join2 and join3, whereas it is much harder for python users to do that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for def join(self, other, joinExprs=None, joinType=None) we don't support given a list of Column as joinExprs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for Python we can support it.

// Note that in this function, we introduce a hack in the case of self-join to automatically
// resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
// Consider this case: df.join(df, df("key") === df("key"))
Expand All @@ -458,7 +475,17 @@ class DataFrame private[sql](

// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
val condition = joinExprs
.foldLeft[Option[catalyst.expressions.Expression]](None) { (opt, condNext) =>
opt match {
case Some(cond) =>
Some(catalyst.expressions.And(cond, condNext.expr))
case None =>
Some(condNext.expr)
}
}

val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), condition)
.queryExecution.analyzed.asInstanceOf[Join]

// If auto self join alias is disabled, return the plan.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest {
Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
}

test("join - join using multiple columns") {
val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")

checkAnswer(
df.join(df2, Seq("int", "int2")),
Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
}

test("join - join using self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")

Expand All @@ -52,6 +61,15 @@ class DataFrameJoinSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}

test("join - self join multiple columns") {
val df1 = testData.as('df1)
val df2 = testData.as('df2)

checkAnswer(
df1.join(df2, Seq($"df1.key" === $"df2.key", $"df1.value" === $"df2.value"), "inner"),
sql("SELECT a.key, a.value, b.key, b.value FROM testData a JOIN testData b ON a.key = b.key AND a.value = b.value").collect().toSeq)
}

test("join - using aliases after self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
Expand Down