 
 package org.apache.spark.sql
 
+import com.google.common.primitives.Ints
+
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
 
 class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
 
-  private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))
+  private val random = new java.util.Random()
 
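+  // 1000 random (key, value) pairs, keys in [0, 10) and values in [0, 100), so each
+  // test below derives its expected answer from the data instead of hard-coding it.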
+  private val data = (0 until 1000).map { _ =>
+    (random.nextInt(10), random.nextInt(100))
+  }
   test("aggregate with object aggregate buffer") {
     val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false))
 
@@ -55,37 +61,66 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
     assert(mergeBuffer.value == data.map(_._1).max)
     assert(agg.eval(mergeBuffer) == data.map(_._1).max)
+
+    // Tests the low-level eval(row: InternalRow) API.
+    val array: Array[Any] = Array(mergeBuffer)
+    val row = new GenericMutableRow(array)
+
+    // Evaluates directly on a row consisting of the aggregation buffer object.
+    assert(agg.eval(row) == data.map(_._1).max)
+
+    // Serializes the aggregation buffer object in place and then evaluates.
+    agg.serializeAggregateBufferInPlace(row)
+    assert(agg.eval(row) == data.map(_._1).max)
+  }
+
+  test("supports SpecificMutableRow as mutable row") {
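+    // The buffer object is serialized into the BinaryType slot at offset 2 of the wider
+    // schema; the aggregated input is column 1 (the value) of each input row.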
+    val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType)
+    val aggBufferOffset = 2
+    val inputBufferObject = 1
+    val buffer = new SpecificMutableRow(aggregationBufferSchema)
+    val agg = new TypedMax(BoundReference(inputBufferObject, IntegerType, nullable = false))
+      .withNewMutableAggBufferOffset(aggBufferOffset)
+      .withNewInputAggBufferOffset(inputBufferObject)
+
+    agg.initialize(buffer)
+    data.foreach { kv =>
+      val input = InternalRow(kv._1, kv._2)
+      agg.update(buffer, input)
+    }
+    assert(agg.eval(buffer) == data.map(_._2).max)
   }
 
   test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") {
     val df = data.toDF("a", "b")
     val max = new TypedMax($"a".expr)
 
-    // Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate
-    // buffer attributes are mutable fields (every field can be mutated inline like int, long...)
-    val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+    // Always uses SortAggregateExec.
     val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
-    assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec])
+    assert(sparkPlan.isInstanceOf[SortAggregateExec])
   }
 
   test("dataframe aggregate with object aggregate buffer, no group by") {
-    val df = data.toDF("a", "b").coalesce(2)
-    checkAnswer(
-      df.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")),
-      Seq(Row(6, 7, 3, 7))
-    )
+    val df = data.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value"))
+    val maxKey = data.map(_._1).max
+    val countKey = data.size
+    val maxValue = data.map(_._2).max
+    val countValue = data.size
+    val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+    checkAnswer(query, expected)
   }
 
   test("dataframe aggregate with object aggregate buffer, with group by") {
-    val df = data.toDF("a", "b").coalesce(2)
-    checkAnswer(
-      df.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")),
-      Seq(
-        Row(0, 5, 3, 5),
-        Row(1, 4, 3, 4),
-        Row(3, 6, 1, 6)
-      )
-    )
+    val df = data.toDF("value", "key").coalesce(2)
+    val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value"))
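+    // Expected: one row per key with (max value, count, max value), mirroring the
+    // aggregate expressions above.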
+    val expected = data.groupBy(_._2).toSeq.map { group =>
+      val (key, values) = group
+      val valueMax = values.map(_._1).max
+      val countValue = values.size
+      Row(key, valueMax, countValue, valueMax)
+    }
+    checkAnswer(query, expected)
   }
 
   test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") {
@@ -102,6 +137,36 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
       Seq.empty[Row])
   }
 
+  test("TypedImperativeAggregate should not break Window function") {
+    val df = data.toDF("key", "value")
+    // OVER (PARTITION BY key ORDER BY value ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+    val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0)
+
+    val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w),
+      typedMax($"value").over(w))
+
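+    // Expected: within each key partition, a running sum/max over the values sorted
+    // ascending, matching the UNBOUNDED PRECEDING .. CURRENT ROW frame.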
+    val expected = data.groupBy(_._1).toSeq.flatMap { group =>
+      val (key, values) = group
+      val sortedValues = values.map(_._2).sorted
+
+      var outputRows = Seq.empty[Row]
+      var i = 0
+      while (i < sortedValues.size) {
+        val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1)
+        val sumKey = key * unboundedPrecedingAndCurrent.size
+        val maxKey = key
+        val sumValue = unboundedPrecedingAndCurrent.sum
+        val maxValue = unboundedPrecedingAndCurrent.max
+
+        outputRows :+= Row(sumKey, maxKey, sumValue, maxValue)
+        i += 1
+      }
+
+      outputRows
+    }
+    checkAnswer(query, expected)
+  }
+
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr)
     Column(max.toAggregateExpression())
@@ -159,14 +224,10 @@ object TypedImperativeAggregateSuite {
 
     override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue]
 
-    override def serialize(buffer: MaxValue): Any = buffer.value
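+    // Round-trips the buffered max through Guava's Ints as a 4-byte big-endian array.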
+    override def serialize(buffer: MaxValue): Array[Byte] = Ints.toByteArray(buffer.value)
 
-    override def aggregationBufferStorageFormatSqlType: DataType = IntegerType
-
-    override def deserialize(storageFormat: Any): MaxValue = {
-      storageFormat match {
-        case i: Int => new MaxValue(i)
-      }
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      new MaxValue(Ints.fromByteArray(storageFormat))
     }
   }
 