45 changes: 36 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* :: Experimental ::
@@ -83,7 +83,6 @@ class Dataset[T] private[sql](

/**
* Returns the schema of the encoded form of the objects in this [[Dataset]].
*
* @since 1.6.0
*/
def schema: StructType = resolvedTEncoder.schema
@@ -185,7 +184,6 @@ class Dataset[T] private[sql](
* .transform(featurize)
* .transform(...)
* }}}
*
* @since 1.6.0
*/
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
@@ -453,6 +451,21 @@ class Dataset[T] private[sql](
c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]

/**
* Returns a new [[Dataset]] by sampling a fraction of records.
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] =
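// Sample's first two arguments bound the sampling range: without replacement a
// record is kept when its random draw falls in [0.0, fraction); with
// replacement, `fraction` is the expected number of times each record appears.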
withPlan(Sample(0.0, fraction, withReplacement, seed, _))

/**
* Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
sample(withReplacement, fraction, Utils.random.nextLong)
}
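
For context, a minimal usage sketch of the new overloads (illustrative only: it assumes a Spark 1.6 `SQLContext` with `sqlContext.implicits._` imported, and the data and names are hypothetical):

```scala
import sqlContext.implicits._

// `fraction` is a per-record probability, not an exact count, so the sampled
// size varies around fraction * n.
val ds = sqlContext.range(0, 100).as[Long]

val fixedSeed  = ds.sample(withReplacement = false, fraction = 0.1, seed = 42L)
val randomSeed = ds.sample(withReplacement = true, fraction = 0.1) // seed drawn from Utils.random

fixedSeed.collect().foreach(println) // deterministic for a given seed and partitioning
```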

/* **************** *
* Set operations *
* **************** */
@@ -511,13 +524,17 @@ class Dataset[T] private[sql](
* types as well as working with relational data where either side of the join has column
* names in common.
*
* @param other Right side of the join.
* @param condition Join expression.
* @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
Review comment (Contributor): you need to update the documentation to explain what options are available for `joinType`.

val left = this.logicalPlan
val right = other.logicalPlan

val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr)))
val joined = sqlContext.executePlan(
  Join(left, right, JoinType(joinType), Some(condition.expr)))
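// The analyzed join output is the left schema followed by the right schema;
// split it back into halves so each side can be re-encoded as its own object.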
val leftOutput = joined.analyzed.output.take(left.output.length)
val rightOutput = joined.analyzed.output.takeRight(right.output.length)

@@ -540,6 +557,18 @@ class Dataset[T] private[sql](
}
}

/**
* Joins this [[Dataset]] with another [[Dataset]] using an inner join, returning a [[Tuple2]]
* for each pair where `condition` evaluates to true.
*
* @param other Right side of the join.
* @param condition Join expression.
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
joinWith(other, condition, "inner")
}
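
As an illustration, a hedged sketch of both overloads (hypothetical data; assumes `sqlContext.implicits._` is in scope — the boxed `Integer` mirrors the outer-join test below, since unmatched rows come back with null fields):

```scala
val left  = Seq(("a", 1), ("b", 2)).toDS().as("l")
val right = Seq(("a", new Integer(10))).toDS().as("r")

// Inner join (the default overload): only matching pairs survive.
val inner: Dataset[((String, Int), (String, Integer))] =
  left.joinWith(right, $"l._1" === $"r._1")

// Outer join: unmatched rows are kept, with null fields on the missing side,
// which is why nullable (boxed) field types matter there.
val outer = left.joinWith(right, $"l._1" === $"r._1", "outer")
```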

/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -584,7 +613,6 @@ class Dataset[T] private[sql](
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -594,7 +622,6 @@ class Dataset[T] private[sql](
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
36 changes: 29 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(1, 2).toDS().as("b")

checkAnswer(
ds1.joinWith(ds2, $"a.value" === $"b.value"),
ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"),
(1, 1), (2, 2))
}

test("joinWith, expression condition") {
val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()
test("joinWith, expression condition, outer join") {
val nullInteger = null.asInstanceOf[Integer]
val nullString = null.asInstanceOf[String]
val ds1 = Seq(ClassNullableData("a", 1),
ClassNullableData("c", 3)).toDS()
val ds2 = Seq(("a", new Integer(1)),
("b", new Integer(2))).toDS()

checkAnswer(
ds1.joinWith(ds2, $"_1" === $"a"),
(ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
ds1.joinWith(ds2, $"_1" === $"a", "outer"),
(ClassNullableData("a", 1), ("a", new Integer(1))),
(ClassNullableData("c", 3), (nullString, nullInteger)),
(ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
}

test("joinWith tuple with primitive, expression") {
@@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
((("a", 1), ("a", 1)), ("a", 1)),
((("b", 2), ("b", 2)), ("b", 2)))

}

test("groupBy function, keys") {
@@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1 -> "a", 2 -> "bc", 3 -> "d")
}

test("sample with replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
5, 10, 52, 73)
}

test("sample without replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
3, 17, 27, 58, 62)
}

test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")
@@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {


case class ClassData(a: String, b: Int)
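// Uses boxed Integer (not primitive Int) so outer-join results can carry null fields.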
case class ClassNullableData(a: String, b: Integer)

/**
* A class used to test serialization using encoders. This class throws exceptions when using