
Commit d3108ab

fix review comments
1 parent 0fdc1ea commit d3108ab


2 files changed (+104 additions, -71 deletions)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 14 additions & 42 deletions
@@ -415,8 +415,8 @@ abstract class DeclarativeAggregate
  * 2. Upon each input row, the framework calls
  *    `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
  * 3. After processing all rows of current group (group by key), the framework will serialize
- *    aggregation buffer object T to SparkSQL internally supported underlying storage format, and
- *    persist the serializable format to disk if needed.
+ *    aggregation buffer object T to storage format (Array[Byte]) and persist the Array[Byte]
+ *    to disk if needed.
  * 4. The framework moves on to next group, until all groups have been processed.
  *
  * Shuffling exchange data to Reducer tasks...
@@ -426,7 +426,7 @@ abstract class DeclarativeAggregate
  * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
  *    buffer object (type T) for merging.
  * 2. For each aggregation output of Stage 1, The framework de-serializes the storage
- *    format and generates one input aggregation object (type T).
+ *    format (Array[Byte]) and produces one input aggregation object (type T).
  * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
  *    to merge the input aggregation object into aggregation buffer object.
  * 4. After processing all input aggregation objects of current group (group by key), the framework
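For orientation, here is a minimal sketch of the Mapper/Reducer lifecycle the comment above describes. SimpleTypedAgg and the two driver functions below are illustrative stand-ins, not the real TypedImperativeAggregate[T] interface or framework code; only the buffer-object flow and the Array[Byte] hand-off are modeled.

import org.apache.spark.sql.catalyst.InternalRow

object LifecycleSketch {
  // Simplified stand-in for TypedImperativeAggregate[T]: just the methods the comment names.
  trait SimpleTypedAgg[T] {
    def createAggregationBuffer(): T
    def update(buffer: T, input: InternalRow): Unit
    def merge(buffer: T, input: T): Unit
    def serialize(buffer: T): Array[Byte]
    def deserialize(bytes: Array[Byte]): T
    def eval(buffer: T): Any
  }

  // Mapper side, steps 1-3: build one buffer object per group, ship it as Array[Byte].
  def mapSide[T](agg: SimpleTypedAgg[T], rowsOfGroup: Iterator[InternalRow]): Array[Byte] = {
    val buffer = agg.createAggregationBuffer()           // 1. empty aggregation buffer object T
    rowsOfGroup.foreach(row => agg.update(buffer, row))  // 2. update once per input row
    agg.serialize(buffer)                                // 3. serialize to Array[Byte] for shuffle/spill
  }

  // Reducer side, steps 1-3: deserialize each partial result, merge, then evaluate the group.
  def reduceSide[T](agg: SimpleTypedAgg[T], partials: Iterator[Array[Byte]]): Any = {
    val merged = agg.createAggregationBuffer()
    partials.foreach(bytes => agg.merge(merged, agg.deserialize(bytes)))
    agg.eval(merged)
  }
}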
@@ -474,39 +474,11 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   /** Returns the class of aggregation buffer object */
   def aggregationBufferClass: Class[T]
 
-  /** Serializes the aggregation buffer object T to Spark-sql internally supported storage format */
-  def serialize(buffer: T): Any
+  /** Serializes the aggregation buffer object T to Array[Byte] */
+  def serialize(buffer: T): Array[Byte]
 
-  /** De-serializes the storage format, and produces aggregation buffer object T */
-  def deserialize(storageFormat: Any): T
-
-  /**
-   * Returns the aggregation-buffer-object storage format's Sql type.
-   *
-   * Here is a list of supported storage format and corresponding Sql type:
-   *
-   * {{{
-   *   aggregation buffer object's Storage format     | storage format's Sql type
-   *   ------------------------------------------------------------------------------------------
-   *   Array[Byte] (*)                                | BinaryType (*)
-   *   Null                                           | NullType
-   *   Boolean                                        | BooleanType
-   *   Byte                                           | ByteType
-   *   Short                                          | ShortType
-   *   Int                                            | IntegerType
-   *   Long                                           | LongType
-   *   Float                                          | FloatType
-   *   Double                                         | DoubleType
-   *   org.apache.spark.sql.types.Decimal             | DecimalType
-   *   org.apache.spark.unsafe.types.UTF8String       | StringType
-   *   org.apache.spark.unsafe.types.CalendarInterval | CalendarIntervalType
-   *   org.apache.spark.sql.catalyst.util.MapData     | MapType
-   *   org.apache.spark.sql.catalyst.util.ArrayData   | ArrayType
-   *   org.apache.spark.sql.catalyst.InternalRow      |
-   * }}}
-   *
-   */
-  def aggregationBufferStorageFormatSqlType: DataType
+  /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
+  def deserialize(storageFormat: Array[Byte]): T
 
   final override def initialize(buffer: MutableRow): Unit = {
     val bufferObject = createAggregationBuffer()
@@ -519,29 +491,29 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   }
 
   final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
-    val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T]
-    val inputObject = deserialize(field(inputBuffer, inputAggBufferOffset))
+    val bufferObject = field[T](buffer, mutableAggBufferOffset)
+    val inputObject = deserialize(field[Array[Byte]](inputBuffer, inputAggBufferOffset))
     merge(bufferObject, inputObject)
   }
 
   final override def eval(buffer: InternalRow): Any = {
-    val bufferObject = field(buffer, mutableAggBufferOffset)
+    val bufferObject = field[AnyRef](buffer, mutableAggBufferOffset)
     if (bufferObject.getClass == aggregationBufferClass) {
       // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly
       // on the object aggregation buffer without intermediate serializing/de-serializing.
      eval(bufferObject.asInstanceOf[T])
    } else {
-      eval(deserialize(bufferObject))
+      eval(deserialize(bufferObject.asInstanceOf[Array[Byte]]))
    }
  }
 
-  private def field(input: InternalRow, offset: Int): AnyRef = {
-    input.get(offset, null)
+  private def field[U](input: InternalRow, fieldIndex: Int): U = {
+    input.get(fieldIndex, null).asInstanceOf[U]
  }
 
   final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
     // Underlying storage type for the aggregation buffer object
-    Seq(AttributeReference("buf", aggregationBufferStorageFormatSqlType)())
+    Seq(AttributeReference("buf", BinaryType)())
  }
 
   final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
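With the storage format fixed to Array[Byte], the single buffer attribute is always BinaryType, so the per-type mapping that aggregationBufferStorageFormatSqlType used to provide is no longer needed. As a rough illustration of what the revised contract asks of an implementor (the names below are hypothetical, not from this commit), a buffer holding a running Long could be encoded with java.nio.ByteBuffer:

import java.nio.ByteBuffer

object SerializationSketch {
  // Hypothetical aggregation buffer object, used only for illustration.
  class LongSumBuffer(var value: Long)

  // serialize: buffer object -> Array[Byte] (what ends up in the BinaryType buffer column)
  def serialize(buffer: LongSumBuffer): Array[Byte] =
    ByteBuffer.allocate(java.lang.Long.BYTES).putLong(buffer.value).array()

  // deserialize: Array[Byte] -> buffer object (what merge() and eval() work on)
  def deserialize(storageFormat: Array[Byte]): LongSumBuffer =
    new LongSumBuffer(ByteBuffer.wrap(storageFormat).getLong)
}

The TypedMax aggregate in the test suite below does the same for an Int via Guava's Ints.toByteArray and Ints.fromByteArray.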

sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala

Lines changed: 90 additions & 29 deletions
@@ -17,21 +17,27 @@
 
 package org.apache.spark.sql
 
+import com.google.common.primitives.Ints
+
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
 
 class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
 
-  private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))
+  private val random = new java.util.Random()
 
+  private val data = (0 until 1000).map { _ =>
+    (random.nextInt(10), random.nextInt(100))
+  }
   test("aggregate with object aggregate buffer") {
     val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false))
 
@@ -55,37 +61,66 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
     assert(mergeBuffer.value == data.map(_._1).max)
     assert(agg.eval(mergeBuffer) == data.map(_._1).max)
+
+    // Tests the low-level eval(row: InternalRow) API.
+    val array: Array[Any] = Array(mergeBuffer)
+    val row = new GenericMutableRow(array)
+
+    // Evaluates directly on a row consisting of the aggregation buffer object.
+    assert(agg.eval(row) == data.map(_._1).max)
+
+    // Serializes the aggregation buffer object and then evaluates.
+    agg.serializeAggregateBufferInPlace(row)
+    assert(agg.eval(row) == data.map(_._1).max)
+  }
+
+  test("supports SpecificMutableRow as mutable row") {
+    val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType)
+    val aggBufferOffset = 2
+    val inputBufferObject = 1
+    val buffer = new SpecificMutableRow(aggregationBufferSchema)
+    val agg = new TypedMax(BoundReference(inputBufferObject, IntegerType, nullable = false))
+      .withNewMutableAggBufferOffset(aggBufferOffset)
+      .withNewInputAggBufferOffset(inputBufferObject)
+
+    agg.initialize(buffer)
+    data.foreach { kv =>
+      val input = InternalRow(kv._1, kv._2)
+      agg.update(buffer, input)
+    }
+    assert(agg.eval(buffer) == data.map(_._2).max)
   }
 
   test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") {
     val df = data.toDF("a", "b")
     val max = new TypedMax($"a".expr)
 
-    // Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate
-    // buffer attributes are mutable fields (every field can be mutated inline like int, long...)
-    val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+    // Always uses SortAggregateExec
     val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
-    assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec])
+    assert(sparkPlan.isInstanceOf[SortAggregateExec])
   }
 
   test("dataframe aggregate with object aggregate buffer, no group by") {
-    val df = data.toDF("a", "b").coalesce(2)
-    checkAnswer(
-      df.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")),
-      Seq(Row(6, 7, 3, 7))
-    )
+    val df = data.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value"))
+    val maxKey = data.map(_._1).max
+    val countKey = data.size
+    val maxValue = data.map(_._2).max
+    val countValue = data.size
+    val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+    checkAnswer(query, expected)
   }
 
   test("dataframe aggregate with object aggregate buffer, with group by") {
-    val df = data.toDF("a", "b").coalesce(2)
-    checkAnswer(
-      df.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")),
-      Seq(
-        Row(0, 5, 3, 5),
-        Row(1, 4, 3, 4),
-        Row(3, 6, 1, 6)
-      )
-    )
+    val df = data.toDF("value", "key").coalesce(2)
+    val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value"))
+    val expected = data.groupBy(_._2).toSeq.map { group =>
+      val (key, values) = group
+      val valueMax = values.map(_._1).max
+      val countValue = values.size
+      Row(key, valueMax, countValue, valueMax)
+    }
+    checkAnswer(query, expected)
   }
 
   test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") {
@@ -102,6 +137,36 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
       Seq.empty[Row])
   }
 
+  test("TypedImperativeAggregate should not break Window function") {
+    val df = data.toDF("key", "value")
+    // OVER (PARTITION BY key ORDER BY value ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+    val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0)
+
+    val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w),
+      typedMax($"value").over(w))
+
+    val expected = data.groupBy(_._1).toSeq.flatMap { group =>
+      val (key, values) = group
+      val sortedValues = values.map(_._2).sorted
+
+      var outputRows = Seq.empty[Row]
+      var i = 0
+      while (i < sortedValues.size) {
+        val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1)
+        val sumKey = key * unboundedPrecedingAndCurrent.size
+        val maxKey = key
+        val sumValue = unboundedPrecedingAndCurrent.sum
+        val maxValue = unboundedPrecedingAndCurrent.max
+
+        outputRows :+= Row(sumKey, maxKey, sumValue, maxValue)
+        i += 1
+      }
+
+      outputRows
+    }
+    checkAnswer(query, expected)
+  }
+
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr)
     Column(max.toAggregateExpression())
@@ -159,14 +224,10 @@ object TypedImperativeAggregateSuite {
 
     override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue]
 
-    override def serialize(buffer: MaxValue): Any = buffer.value
+    override def serialize(buffer: MaxValue): Array[Byte] = Ints.toByteArray(buffer.value)
 
-    override def aggregationBufferStorageFormatSqlType: DataType = IntegerType
-
-    override def deserialize(storageFormat: Any): MaxValue = {
-      storageFormat match {
-        case i: Int => new MaxValue(i)
-      }
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      new MaxValue(Ints.fromByteArray(storageFormat))
     }
   }
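A quick round-trip check (illustrative only, not part of the commit) of the property the new serialize/deserialize pair relies on: deserialization must invert serialization so that merge() and eval() see the same buffer value that was written out.

import com.google.common.primitives.Ints

val bytes: Array[Byte] = Ints.toByteArray(42)  // what TypedMax.serialize produces for MaxValue(42)
assert(Ints.fromByteArray(bytes) == 42)        // what TypedMax.deserialize recovers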
