Skip to content

Commit 0b2da51

Browse files
committed
SPARK-8153 Add configuration for disabling partial aggregation in runtime
1 parent 7467b52 commit 0b2da51

File tree

9 files changed

+201
-10
lines changed

9 files changed

+201
-10
lines changed

core/src/main/scala/org/apache/spark/util/collection/BitSet.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class BitSet(numBits: Int) extends Serializable {
3232
*/
3333
def capacity: Int = numWords * 64
3434

35+
def clear(): Unit = {
36+
java.util.Arrays.fill(words, 0L)
37+
}
38+
3539
/**
3640
* Set all the bits up to a given index
3741
*/

core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
104104
/** Return true if this set contains the specified element. */
105105
def contains(k: T): Boolean = getPos(k) != INVALID_POS
106106

107+
def clear(): Unit = {
108+
_size = 0
109+
_bitset.clear()
110+
_capacity = nextPowerOf2(initialCapacity)
111+
_data = new Array[T](_capacity)
112+
}
113+
107114
/**
108115
* Add an element to the set. If the set is over capacity after the insertion, grow the set
109116
* and rehash all elements.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ trait AggregateExpression1 extends AggregateExpression {
4141
* of input rows/
4242
*/
4343
def newInstance(): AggregateFunction1
44+
45+
def reset(): AggregateFunction1 = newInstance()
4446
}
4547

4648
/**
@@ -81,6 +83,8 @@ abstract class AggregateFunction1 extends LeafExpression with Serializable {
8183

8284
def update(input: InternalRow): Unit
8385

86+
def reset(): AggregateFunction1
87+
8488
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
8589
throw new UnsupportedOperationException(
8690
"AggregateFunction1 should not be used for generated aggregates")
@@ -117,6 +121,11 @@ case class MinFunction(expr: Expression, base: AggregateExpression1) extends Agg
117121
}
118122
}
119123

124+
override def reset(): AggregateFunction1 = {
125+
currentMin.update(null)
126+
this
127+
}
128+
120129
override def eval(input: InternalRow): Any = currentMin.value
121130
}
122131

@@ -150,6 +159,11 @@ case class MaxFunction(expr: Expression, base: AggregateExpression1) extends Agg
150159
}
151160
}
152161

162+
override def reset(): AggregateFunction1 = {
163+
currentMax.update(null)
164+
this
165+
}
166+
153167
override def eval(input: InternalRow): Any = currentMax.value
154168
}
155169

@@ -178,6 +192,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression1) extends A
178192
}
179193
}
180194

195+
override def reset(): AggregateFunction1 = {
196+
count = 0
197+
this
198+
}
199+
181200
override def eval(input: InternalRow): Any = count
182201
}
183202

@@ -218,6 +237,11 @@ case class CountDistinctFunction(
218237
}
219238
}
220239

240+
override def reset(): AggregateFunction1 = {
241+
seen.clear()
242+
this
243+
}
244+
221245
override def eval(input: InternalRow): Any = seen.size.toLong
222246
}
223247

@@ -251,6 +275,11 @@ case class CollectHashSetFunction(
251275
}
252276
}
253277

278+
override def reset(): AggregateFunction1 = {
279+
seen.clear()
280+
this
281+
}
282+
254283
override def eval(input: InternalRow): Any = {
255284
seen
256285
}
@@ -285,6 +314,11 @@ case class CombineSetsAndCountFunction(
285314
}
286315
}
287316

317+
override def reset(): AggregateFunction1 = {
318+
seen.clear()
319+
this
320+
}
321+
288322
override def eval(input: InternalRow): Any = seen.size.toLong
289323
}
290324

@@ -333,6 +367,9 @@ case class ApproxCountDistinctPartitionFunction(
333367
}
334368

335369
override def eval(input: InternalRow): Any = hyperLogLog
370+
371+
override def reset(): AggregateFunction1 =
372+
new ApproxCountDistinctPartitionFunction(expr, base, relativeSD)
336373
}
337374

338375
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
@@ -361,6 +398,9 @@ case class ApproxCountDistinctMergeFunction(
361398
}
362399

363400
override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
401+
402+
override def reset(): AggregateFunction1 =
403+
new ApproxCountDistinctMergeFunction(expr, base, relativeSD)
364404
}
365405

366406
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
@@ -471,6 +511,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
471511
sum.update(addFunction(evaluatedExpr), input)
472512
}
473513
}
514+
515+
override def reset(): AggregateFunction1 = {
516+
count = 0
517+
sum.update(zero.eval(null))
518+
this
519+
}
474520
}
475521

476522
case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
@@ -528,6 +574,11 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
528574
sum.update(addFunction, input)
529575
}
530576

577+
override def reset(): AggregateFunction1 = {
578+
sum.update(null)
579+
this
580+
}
581+
531582
override def eval(input: InternalRow): Any = {
532583
expr.dataType match {
533584
case DecimalType.Fixed(_, _) =>
@@ -576,6 +627,11 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
576627
}
577628
}
578629

630+
override def reset(): AggregateFunction1 = {
631+
seen.clear()
632+
this
633+
}
634+
579635
override def eval(input: InternalRow): Any = {
580636
if (seen.size == 0) {
581637
null
@@ -617,6 +673,11 @@ case class CombineSetsAndSumFunction(
617673
}
618674
}
619675

676+
override def reset(): AggregateFunction1 = {
677+
seen.clear()
678+
this
679+
}
680+
620681
override def eval(input: InternalRow): Any = {
621682
val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
622683
if (casted.size == 0) {
@@ -656,6 +717,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
656717
}
657718
}
658719

720+
override def reset(): AggregateFunction1 = {
721+
result = null
722+
this
723+
}
724+
659725
override def eval(input: InternalRow): Any = result
660726
}
661727

@@ -687,6 +753,11 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
687753
}
688754
}
689755

756+
override def reset(): AggregateFunction1 = {
757+
result = null
758+
this
759+
}
760+
690761
override def eval(input: InternalRow): Any = {
691762
result
692763
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean
150150
value = expression.eval(input)
151151
}
152152

153+
def update(value: Any): Unit = {
154+
this.value = value
155+
}
156+
153157
override def eval(input: InternalRow): Any = value
154158
}

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,20 @@ private[spark] object SQLConf {
114114
}
115115
}, _.toString, doc, isPublic)
116116

117+
def floatConf(
118+
key: String,
119+
defaultValue: Option[Float] = None,
120+
doc: String = "",
121+
isPublic: Boolean = true): SQLConfEntry[Float] =
122+
SQLConfEntry(key, defaultValue, { v =>
123+
try {
124+
v.toFloat
125+
} catch {
126+
case _: NumberFormatException =>
127+
throw new IllegalArgumentException(s"$key should be float, but was $v")
128+
}
129+
}, _.toString, doc, isPublic)
130+
117131
def doubleConf(
118132
key: String,
119133
defaultValue: Option[Double] = None,
@@ -420,6 +434,17 @@ private[spark] object SQLConf {
420434
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
421435
defaultValue = Some(true), doc = "<TODO>")
422436

437+
val PARTIAL_AGGREGATION_CHECK_INTERVAL = intConf(
438+
"spark.sql.partial.aggregation.checkInterval",
439+
defaultValue = Some(1000000),
440+
doc = "Number of input rows for checking aggregation status of hash")
441+
442+
val PARTIAL_AGGREGATION_MIN_REDUCTION = floatConf(
443+
"spark.sql.partial.aggregation.minReduction",
444+
defaultValue = Some(0.5f),
445+
doc = "Partial aggregation will be turned off when reduction ratio by hash is not " +
446+
"enough to this threshold")
447+
423448
object Deprecated {
424449
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
425450
}
@@ -527,6 +552,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
527552

528553
private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS)
529554

555+
private[spark] def partialAggregationCheckInterval: Int =
556+
getConf(PARTIAL_AGGREGATION_CHECK_INTERVAL)
557+
558+
private[spark] def partialAggregationMinReduction: Float =
559+
getConf(PARTIAL_AGGREGATION_MIN_REDUCTION)
560+
530561
/** ********************** SQLConf functionality methods ************ */
531562

532563
/** Set Spark SQL configuration properties. */

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import java.util.HashMap
20+
import java.util
2121

2222
import org.apache.spark.annotation.DeveloperApi
2323
import org.apache.spark.rdd.RDD
@@ -43,6 +43,8 @@ case class Aggregate(
4343
partial: Boolean,
4444
groupingExpressions: Seq[Expression],
4545
aggregateExpressions: Seq[NamedExpression],
46+
checkInterval: Int,
47+
minReduction: Float,
4648
child: SparkPlan)
4749
extends UnaryNode {
4850

@@ -155,16 +157,22 @@ case class Aggregate(
155157
}
156158
} else {
157159
child.execute().mapPartitions { iter =>
158-
val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
160+
val hashTable = new util.HashMap[InternalRow, Array[AggregateFunction1]]
159161
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
160162

163+
var numInput: Int = 0
164+
var numOutput: Int = 0
165+
166+
var disabled: Boolean = false
161167
var currentRow: InternalRow = null
162-
while (iter.hasNext) {
168+
while (!disabled && iter.hasNext) {
169+
numInput += 1
163170
currentRow = iter.next()
164171
numInputRows += 1
165172
val currentGroup = groupingProjection(currentRow)
166173
var currentBuffer = hashTable.get(currentGroup)
167174
if (currentBuffer == null) {
175+
numOutput += 1
168176
currentBuffer = newAggregateBuffer()
169177
hashTable.put(currentGroup.copy(), currentBuffer)
170178
}
@@ -174,15 +182,22 @@ case class Aggregate(
174182
currentBuffer(i).update(currentRow)
175183
i += 1
176184
}
185+
if (partial && minReduction > 0 &&
186+
(numInput % checkInterval) == 0 && (numOutput / numInput > minReduction)) {
187+
log.info("Hash aggregation is disabled by insufficient reduction ratio " +
188+
(numOutput / numInput) + " which is expected to be less than " + minReduction)
189+
disabled = true
190+
}
177191
}
178192

179-
new Iterator[InternalRow] {
180-
private[this] val hashTableIter = hashTable.entrySet().iterator()
181-
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
182-
private[this] val resultProjection =
193+
val joinedRow = new JoinedRow
194+
val aggregateResults = new GenericMutableRow(computedAggregates.length)
195+
val resultProjection =
183196
new InterpretedMutableProjection(
184197
resultExpressions, computedSchema ++ namedGroups.map(_._2))
185-
private[this] val joinedRow = new JoinedRow
198+
199+
val result = new Iterator[InternalRow] {
200+
private[this] val hashTableIter = hashTable.entrySet().iterator()
186201

187202
override final def hasNext: Boolean = hashTableIter.hasNext
188203

@@ -202,6 +217,27 @@ case class Aggregate(
202217
resultProjection(joinedRow(aggregateResults, currentGroup))
203218
}
204219
}
220+
if (!iter.hasNext) {
221+
result
222+
} else {
223+
val currentBuffer = newAggregateBuffer()
224+
result ++ new Iterator[InternalRow] {
225+
override final def hasNext: Boolean = iter.hasNext
226+
227+
override final def next(): InternalRow = {
228+
currentRow = iter.next()
229+
val currentGroup = groupingProjection(currentRow)
230+
231+
var i = 0
232+
while (i < currentBuffer.length) {
233+
currentBuffer(i).reset().update(currentRow)
234+
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
235+
i += 1
236+
}
237+
resultProjection(joinedRow(aggregateResults, currentGroup))
238+
}
239+
}
240+
}
205241
}
206242
}
207243
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.physical._
2727
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
2828
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
2929
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
30-
import org.apache.spark.sql.types._
3130
import org.apache.spark.sql.{SQLContext, Strategy, execution}
3231

3332
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -165,10 +164,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
165164
partial = false,
166165
namedGroupingAttributes,
167166
rewrittenAggregateExpressions,
167+
-1, -1,
168168
execution.Aggregate(
169169
partial = true,
170170
groupingExpressions,
171171
partialComputation,
172+
sqlContext.conf.partialAggregationCheckInterval,
173+
sqlContext.conf.partialAggregationMinReduction,
172174
planLater(child))) :: Nil
173175

174176
case _ => Nil
@@ -360,7 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
360362
Nil
361363
} else {
362364
Utils.checkInvalidAggregateFunction2(a)
363-
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
365+
execution.Aggregate(partial = false, group, agg, -1, -1, planLater(child)) :: Nil
364366
}
365367
}
366368
case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>

0 commit comments

Comments
 (0)