Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add ReduceAggregator.
  • Loading branch information
viirya committed Jul 18, 2016
commit 7e8d8c116552642573cc89bd11fc2e82f2a0f82a
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.ReduceAggregator

/**
* :: Experimental ::
Expand Down Expand Up @@ -179,30 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
*/
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
val encoder = encoderFor[V]
val intEncoder: ExpressionEncoder[Int] = ExpressionEncoder()
val aggregator: TypedColumn[V, V] = new Aggregator[V, (Int, V), V] {
def bufferEncoder: Encoder[(Int, V)] = ExpressionEncoder.tuple(intEncoder, encoder)
def outputEncoder: Encoder[V] = encoder

def zero: (Int, V) = (0, null.asInstanceOf[V])
def reduce(reducedValue: (Int, V), value: V): (Int, V) = {
if (reducedValue._1 == 0) {
(1, value)
} else {
(1, f(reducedValue._2, value))
}
}
def merge(buf1: (Int, V), buf2: (Int, V)): (Int, V) = {
if (buf1._1 == 0) {
buf2
} else if (buf2._2 == 0) {
buf1
} else {
(1, f(buf1._2, buf2._2))
}
}
def finish(result: (Int, V)): V = result._2
}.toColumn
val aggregator: TypedColumn[V, V] = new ReduceAggregator(f, encoder).toColumn

agg(aggregator)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.expressions

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

/**
* :: Experimental ::
* A generic class for reduce aggregations, which accepts a reduce function that can be used to take
* all of the elements of a group and reduce them to a single value.
*
* @tparam T The input and output type for the reduce function.
* @param func The reduce aggregation function.
* @param encoder The encoder for the input and output type of the reduce function.
* @since 2.1.0
*/
@Experimental
private[sql] class ReduceAggregator[T](func: (T, T) => T, encoder: ExpressionEncoder[T])
extends Aggregator[T, (Boolean, T), T] {

/**
* A zero value for this aggregation. It is represented as a Tuple2. The first element of the
* tuple is a false boolean value indicating the buffer is not initialized. The second element
* is initialized as a null value.
* @since 2.1.0
*/
override def zero: (Boolean, T) = (false, null.asInstanceOf[T])

override def bufferEncoder: Encoder[(Boolean, T)] =
ExpressionEncoder.tuple(ExpressionEncoder[Boolean](), encoder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Encoders.scalaBoolean?


override def outputEncoder: Encoder[T] = encoder

/**
* Combine two values to produce a new value. If the buffer `b` is not initialized, it simply
* takes the value of `a` and set the initialization flag to `true`.
* @since 2.1.0
*/
override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
if (b._1) {
(true, func(b._2, a))
} else {
(true, a)
}
}

/**
* Merge two intermediate values. As it is possibly that the buffer is just the `zero` value
* coming from empty partition, it checks if the buffers are initialized, and only performs
* merging when they are initialized both.
* @since 2.1.0
*/
override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
if (!b1._1) {
b2
} else if (!b2._1) {
b1
} else {
(true, func(b1._2, b2._2))
}
}

/**
* Transform the output of the reduction. Simply output the value in the buffer.
* @since 2.1.0
*/
override def finish(reduction: (Boolean, T)): T = {
reduction._2
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.ReduceAggregator

class ReduceAggregatorSuite extends SparkFunSuite {
test("zero value") {
val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
val func = (v1: Int, v2: Int) => v1 + v2
val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder)
assert(aggregator.zero == (false, null))
}

test("reduce, merge and finish") {
val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
val func = (v1: Int, v2: Int) => v1 + v2
val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder)

val firstReduce = aggregator.reduce(aggregator.zero, 1)
assert(firstReduce == (true, 1))

val secondReduce = aggregator.reduce(firstReduce, 2)
assert(secondReduce == (true, 3))

val thirdReduce = aggregator.reduce(secondReduce, 3)
assert(thirdReduce == (true, 6))

val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
assert(mergeWithZero1 == (true, 1))

val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
assert(mergeWithZero2 == (true, 3))

val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
assert(mergeTwoReduced == (true, 4))

assert(aggregator.finish(firstReduce)== 1)
assert(aggregator.finish(secondReduce) == 3)
assert(aggregator.finish(thirdReduce) == 6)
assert(aggregator.finish(mergeWithZero1) == 1)
assert(aggregator.finish(mergeWithZero2) == 3)
assert(aggregator.finish(mergeTwoReduced) == 4)
}
}