Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))

/**
 * Compute the elements that this RDD has in common with `other`. Each common element appears
 * only once in the result, even when either input RDD contains it several times.
 *
 * @param other the RDD to intersect with this one
 * @return a new JavaDoubleRDD wrapping the intersection of the two underlying RDDs
 */
def intersection(other: JavaDoubleRDD): JavaDoubleRDD = {
  val common = srdd.intersection(other.srdd)
  fromRDD(common)
}

// Double RDD functions

/** Add up the elements in this RDD. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))

/**
 * Compute the key-value pairs that this RDD has in common with `other`. Each common pair
 * appears only once in the result, even when either input RDD contains it several times.
 *
 * @param other the pair RDD to intersect with this one
 * @return a new JavaPairRDD wrapping the intersection of the two underlying RDDs
 */
def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = {
  val common = rdd.intersection(other.rdd)
  new JavaPairRDD[K, V](common)
}


// first() has to be overridden here so that the generated method has the signature
// 'public scala.Tuple2 first()'; if the trait's definition is used,
// then the method has the signature 'public java.lang.Object first()',
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))


/**
 * Compute the elements that this RDD has in common with `other`. Each common element appears
 * only once in the result, even when either input RDD contains it several times.
 *
 * @param other the RDD to intersect with this one
 * @return a new JavaRDD wrapping the intersection of the two underlying RDDs
 */
def intersection(other: JavaRDD[T]): JavaRDD[T] = {
  val common = rdd.intersection(other.rdd)
  wrapRDD(common)
}

/**
* Return an RDD with the elements from `this` that are not in `other`.
*
Expand Down
31 changes: 31 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,37 @@ public void sparkContextUnion() {
Assert.assertEquals(4, pUnion.count());
}

/**
 * Checks the element count produced by intersection() on each Java RDD flavour:
 * JavaRDD (including an empty input), JavaDoubleRDD and JavaPairRDD.
 */
@SuppressWarnings("unchecked")
@Test
public void intersection() {
  // Plain RDDs: the two inputs share exactly the values 1, 2 and 3.
  JavaRDD<Integer> left = sc.parallelize(Arrays.asList(1, 10, 2, 3, 4, 5));
  JavaRDD<Integer> right = sc.parallelize(Arrays.asList(1, 6, 2, 3, 7, 8));
  JavaRDD<Integer> common = left.intersection(right);
  Assert.assertEquals(3, common.count());

  // Intersecting an empty RDD with a non-empty one must yield nothing.
  JavaRDD<Integer> nothing = sc.parallelize(new ArrayList<Integer>());
  JavaRDD<Integer> nothingInCommon = nothing.intersection(right);
  Assert.assertEquals(0, nothingInCommon.count());

  // Double RDDs: two RDDs built from the same list intersect to that list's two values.
  List<Double> doubleValues = Arrays.asList(1.0, 2.0);
  JavaDoubleRDD doublesA = sc.parallelizeDoubles(doubleValues);
  JavaDoubleRDD doublesB = sc.parallelizeDoubles(doubleValues);
  JavaDoubleRDD commonDoubles = doublesA.intersection(doublesB);
  Assert.assertEquals(2, commonDoubles.count());

  // Pair RDDs: two RDDs built from the same two tuples intersect to those two tuples.
  List<Tuple2<Integer, Integer>> tuples = Arrays.asList(
    new Tuple2<Integer, Integer>(1, 2),
    new Tuple2<Integer, Integer>(3, 4));
  JavaPairRDD<Integer, Integer> pairsA = sc.parallelizePairs(tuples);
  JavaPairRDD<Integer, Integer> pairsB = sc.parallelizePairs(tuples);
  JavaPairRDD<Integer, Integer> commonPairs = pairsA.intersection(pairsB);
  Assert.assertEquals(2, commonPairs.count());
}

@Test
public void sortByKey() {
List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,22 @@ def union(self, other):
return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
self.ctx.serializer)

def intersection(self, other):
    """
    Return the elements that this RDD and C{other} have in common.

    Note: Each common element appears only once in the output, even if
    the input RDDs contained it several times.

    >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5])
    >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8])
    >>> rdd1.intersection(rdd2).collect()
    [1, 2, 3]
    """
    # Turn every element into a key with a dummy value so the two RDDs
    # can be cogrouped by element.
    keyed_self = self.map(lambda elem: (elem, None))
    keyed_other = other.map(lambda elem: (elem, None))
    grouped = keyed_self.cogroup(keyed_other)
    # Keep only the keys that have at least one entry on both sides.
    in_both = grouped.filter(
        lambda kv: (len(kv[1][0]) != 0) and (len(kv[1][1]) != 0))
    return in_both.keys()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably be slightly nicer to write it like this:

return self.map(lambda v: (v, None)) \
    .cogroup(other.map(lambda v: (v, None))) \
    .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
    .keys()

Alternatively, wrap the whole expression in parentheses so that Python does not treat each line break as the end of the statement (though we have used the backslash style before).

Other than that it looks good.


def _reserialize(self):
if self._jrdd_deserializer == self.ctx.serializer:
return self
Expand Down