add reduce to GroupedDataset
cloud-fan committed Nov 11, 2015
commit 7504c6790bf9bad143bce9f259e1ce98a5b40043
21 changes: 15 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,12 +17,10 @@

package org.apache.spark.sql

import java.util.{Iterator => JIterator}

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
@@ -127,15 +125,26 @@ class GroupedDataset[K, T] private[sql](
*/
def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
new Dataset[U](
sqlContext,
MapGroups(func, groupingAttributes, logicalPlan))
flatMap(func)
}

def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
map((key, data) => f.call(key, data.asJava))(encoder)
}

/**
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
*/
def reduce(f: (T, T) => T): Dataset[(K, T)] = {
val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
flatMap(func)(ExpressionEncoder.tuple(kEnc, tEnc))
}

def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
reduce(f.call _)
}

// To ensure valid overloading.
protected def agg(expr: Column, exprs: Column*): DataFrame =
groupedData.agg(expr, exprs: _*)
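
Usage sketch (not part of the diff): how the new Scala `reduce` overload is exercised, mirroring the tests added below. It assumes a SQLContext named `sqlContext` with its implicits imported, as the test suites do.

// Sketch only; assumes `sqlContext` with implicits in scope.
import sqlContext.implicits._

val ds = Seq("abc", "xyz", "hello").toDS()

// Group strings by their length, then concatenate the values in each group.
// The reducer must be commutative and associative, since the order in which
// the elements of a group are combined is not guaranteed.
val reduced = ds.groupBy(_.length).reduce(_ + _)

// Expected contents: (3, "abcxyz") and (5, "hello"), as asserted in the Scala test below.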
@@ -197,6 +197,17 @@ public Iterable<String> call(Integer key, Iterator<String> values) throws Exception

Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());

Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
@Override
public String call(String v1, String v2) throws Exception {
return v1 + v2;
}
});

Assert.assertEquals(
Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
reduced.collectAsList());

List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@@ -218,6 +218,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"a", "30", "b", "3", "c", "1")
}

test("groupBy function, reduce") {
val ds = Seq("abc", "xzy", "hello").toDS()
val agged = ds.groupBy(_.length).reduce(_ + _)

checkAnswer(
agged,
3 -> "abcxyz", 5 -> "hello")
}

test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")