@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou
2525import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , CreateStruct }
2626import org .apache .spark .sql .catalyst .plans .logical ._
2727import org .apache .spark .sql .execution .QueryExecution
28+ import org .apache .spark .sql .expressions .Aggregator
2829
2930/**
3031 * :: Experimental ::
@@ -177,10 +178,33 @@ class KeyValueGroupedDataset[K, V] private[sql](
177178 * @since 1.6.0
178179 */
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
  val encoder = encoderFor[V]
  val intEncoder: ExpressionEncoder[Int] = ExpressionEncoder()

  // Express the reduction as a typed Aggregator so it can be evaluated through `agg`.
  // The buffer is a pair (flag, value): flag == 0 means "no value folded in yet" (the
  // value slot is then a meaningless null placeholder); flag == 1 means the value slot
  // holds the reduction of everything seen so far.
  val aggregator: TypedColumn[V, V] = new Aggregator[V, (Int, V), V] {
    def bufferEncoder: Encoder[(Int, V)] = ExpressionEncoder.tuple(intEncoder, encoder)
    def outputEncoder: Encoder[V] = encoder

    // Empty buffer: nothing has been reduced yet.
    def zero: (Int, V) = (0, null.asInstanceOf[V])

    // Fold one input value into the buffer. The first value is taken as-is; `f` is only
    // applied once there is a previously reduced value to combine with.
    def reduce(reducedValue: (Int, V), value: V): (Int, V) = {
      if (reducedValue._1 == 0) {
        (1, value)
      } else {
        (1, f(reducedValue._2, value))
      }
    }

    // Combine two buffers. An empty side contributes nothing, so the other side is
    // returned unchanged; `f` is applied only when both sides hold a value.
    def merge(buf1: (Int, V), buf2: (Int, V)): (Int, V) = {
      if (buf1._1 == 0) {
        buf2
      } else if (buf2._1 == 0) {
        // Must test the flag (_1), not the value (_2): comparing the value against 0
        // would pass the null placeholder of an empty buffer to `f`, and would discard
        // a genuine reduced value that happens to equal 0.
        buf1
      } else {
        (1, f(buf1._2, buf2._2))
      }
    }

    // Unwrap the reduced value. NOTE(review): yields the null placeholder if no value
    // was ever folded in; presumably every group has at least one row -- confirm.
    def finish(result: (Int, V)): V = result._2
  }.toColumn

  agg(aggregator)
}
185209
186210 /**
0 commit comments