Skip to content

Commit 1135773

Browse files
committed
Support partial aggregation for reduceGroups.
1 parent 1832423 commit 1135773

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou
2525
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
2626
import org.apache.spark.sql.catalyst.plans.logical._
2727
import org.apache.spark.sql.execution.QueryExecution
28+
import org.apache.spark.sql.expressions.Aggregator
2829

2930
/**
3031
* :: Experimental ::
@@ -177,10 +178,33 @@ class KeyValueGroupedDataset[K, V] private[sql](
177178
* @since 1.6.0
178179
*/
179180
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
  // Expressed as a typed Aggregator (instead of flatMapGroups over the whole group iterator)
  // so that the reduction supports partial aggregation: values are folded into buffers with
  // `reduce`, and buffers are then combined with `merge`.
  val valueEncoder = encoderFor[V]
  val flagEncoder: ExpressionEncoder[Int] = ExpressionEncoder()

  // Buffer layout is (flag, value):
  //   flag == 0 -> nothing has been folded in yet; the value slot is only a placeholder.
  //   flag == 1 -> the value slot holds the reduction of everything seen so far.
  val aggregator: TypedColumn[V, V] = new Aggregator[V, (Int, V), V] {
    def bufferEncoder: Encoder[(Int, V)] = ExpressionEncoder.tuple(flagEncoder, valueEncoder)
    def outputEncoder: Encoder[V] = valueEncoder

    // Empty buffer: the placeholder value must never be handed to `f`.
    def zero: (Int, V) = (0, null.asInstanceOf[V])

    // Fold one input value into the buffer; the first value seeds the buffer as-is.
    def reduce(reducedValue: (Int, V), value: V): (Int, V) = {
      if (reducedValue._1 == 0) {
        (1, value)
      } else {
        (1, f(reducedValue._2, value))
      }
    }

    // Combine two buffers. An empty side is skipped so that `f` only ever sees real values.
    def merge(buf1: (Int, V), buf2: (Int, V)): (Int, V) = {
      if (buf1._1 == 0) {
        buf2
      } else if (buf2._1 == 0) {
        // Check the flag slot (_1), not the value slot (_2): comparing the value against 0
        // would let an empty buffer's placeholder reach `f`.
        buf1
      } else {
        (1, f(buf1._2, buf2._2))
      }
    }

    def finish(result: (Int, V)): V = result._2
  }.toColumn

  agg(aggregator)
}
185209

186210
/**

0 commit comments

Comments
 (0)