Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql

import org.apache.spark.sql.functions._
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sort the imports properly


import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
Expand All @@ -26,7 +28,7 @@ import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -493,11 +495,12 @@ class Dataset[T] private[sql](
*
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to update the documentation to explain what options are available for joinType.

val left = this.logicalPlan
val right = other.logicalPlan

val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr)))
val joined = sqlContext.executePlan(Join(left, right, joinType =
JoinType(joinType), Some(condition.expr)))
val leftOutput = joined.analyzed.output.take(left.output.length)
val rightOutput = joined.analyzed.output.takeRight(right.output.length)

Expand All @@ -520,6 +523,25 @@ class Dataset[T] private[sql](
}
}

/**
* Joins this [[Dataset]] using an inner equi-join, returning a [[Tuple2]] for each pair
* where `condition` evaluates to true.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missed a period

*
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] =
  // Delegate to the three-argument overload with the join type fixed to "inner".
  joinWith(other, condition, "inner")

/**
* Joins this [[Dataset]] returning a [[Tuple2]] for each pair using a cartesian join.
*
* Note that cartesian joins are very expensive without an extra filter that can be pushed down.
*
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U]): Dataset[(T, U)] = {
  // No caller-supplied condition: an always-true literal is used as the join predicate,
  // so every left row is paired with every right row.
  joinWith(other, lit(true), "inner")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the extra space

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I'd maybe just remove this for now -- since cartesian joins are too expensive.


/* ************************** *
* Gather to Driver Actions *
* ************************** */
Expand Down
27 changes: 17 additions & 10 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(1, 2).toDS().as("b")

checkAnswer(
ds1.joinWith(ds2, $"a.value" === $"b.value"),
ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"),
(1, 1), (2, 2))
}

test("joinWith, expression condition") {
val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()
test("joinWith, expression condition, outer join") {
val nullInteger = null.asInstanceOf[Integer]
val nullString = null.asInstanceOf[String]
val ds1 = Seq(ClassNullableData("a", new Integer(1)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can just pass in 1, and the compiler will auto-box it for us.

ClassNullableData("c", new Integer(3))).toDS()
val ds2 = Seq(("a", new Integer(1)),
("b", new Integer(2))).toDS()

checkAnswer(
ds1.joinWith(ds2, $"_1" === $"a"),
(ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
ds1.joinWith(ds2, $"_1" === $"a", "outer"),
(ClassNullableData("a", new Integer(1)), ("a", new Integer(1))),
(ClassNullableData("c", new Integer(3)), (nullString, nullInteger)),
(ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
}

test("joinWith tuple with primitive, expression") {
Expand Down Expand Up @@ -350,7 +356,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {

test("self join") {
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true))
val joined = ds.joinWith(ds)
checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}

Expand All @@ -370,7 +376,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
assert(ds.joinWith(ds).collect().toSet ==
Set(
(KryoData(1), KryoData(1)),
(KryoData(1), KryoData(2)),
Expand All @@ -389,7 +395,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
assert(ds.joinWith(ds).collect().toSet ==
Set(
(JavaData(1), JavaData(1)),
(JavaData(1), JavaData(2)),
Expand All @@ -403,7 +409,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()

checkAnswer(
ds1.joinWith(ds2, lit(true)),
ds1.joinWith(ds2),
((nullInt, "1"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
Expand All @@ -413,6 +419,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {


case class ClassData(a: String, b: Int)
case class ClassNullableData(a: String, b: Integer)

/**
* A class used to test serialization using encoders. This class throws exceptions when using
Expand Down