diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 9c15b1188d91..d44a5cfca3e1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -32,6 +32,10 @@ class BitSet(numBits: Int) extends Serializable {
    */
   def capacity: Int = numWords * 64
 
+  def clear(): Unit = {
+    java.util.Arrays.fill(words, 0L)
+  }
+
   /**
    * Set all the bits up to a given index
    */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 60bf4dd7469f..d39e3a4cce5f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -104,6 +104,13 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
   /** Return true if this set contains the specified element. */
   def contains(k: T): Boolean = getPos(k) != INVALID_POS
 
+  def clear(): Unit = {
+    _size = 0
+    _bitset.clear()
+    _capacity = nextPowerOf2(initialCapacity)
+    _data = new Array[T](_capacity)
+  }
+
   /**
    * Add an element to the set. If the set is over capacity after the insertion, grow the set
    * and rehash all elements.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5e8298aaaa9c..6ae6f3856c8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -41,6 +41,8 @@ trait AggregateExpression1 extends AggregateExpression {
    * of input rows/
    */
   def newInstance(): AggregateFunction1
+
+  def reset(): AggregateFunction1 = newInstance()
 }
 
 /**
@@ -81,6 +83,8 @@ abstract class AggregateFunction1 extends LeafExpression with Serializable {
 
   def update(input: InternalRow): Unit
 
+  def reset(): AggregateFunction1
+
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     throw new UnsupportedOperationException(
       "AggregateFunction1 should not be used for generated aggregates")
@@ -117,6 +121,11 @@ case class MinFunction(expr: Expression, base: AggregateExpression1) extends Agg
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    currentMin.update(null)
+    this
+  }
+
   override def eval(input: InternalRow): Any = currentMin.value
 }
 
@@ -150,6 +159,11 @@ case class MaxFunction(expr: Expression, base: AggregateExpression1) extends Agg
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    currentMax.update(null)
+    this
+  }
+
   override def eval(input: InternalRow): Any = currentMax.value
 }
 
@@ -178,6 +192,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression1) extends A
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    count = 0
+    this
+  }
+
   override def eval(input: InternalRow): Any = count
 }
 
@@ -218,6 +237,11 @@ case class CountDistinctFunction(
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    seen.clear()
+    this
+  }
+
   override def eval(input: InternalRow): Any = seen.size.toLong
 }
 
@@ -251,6 +275,11 @@ case class CollectHashSetFunction(
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    seen.clear()
+    this
+  }
+
   override def eval(input: InternalRow): Any = {
     seen
   }
@@ -285,6 +314,11 @@ case class CombineSetsAndCountFunction(
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    seen.clear()
+    this
+  }
+
   override def eval(input: InternalRow): Any = seen.size.toLong
 }
 
@@ -333,6 +367,9 @@ case class ApproxCountDistinctPartitionFunction(
   }
 
   override def eval(input: InternalRow): Any = hyperLogLog
+
+  override def reset(): AggregateFunction1 =
+    new ApproxCountDistinctPartitionFunction(expr, base, relativeSD)
 }
 
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
@@ -361,6 +398,9 @@ case class ApproxCountDistinctMergeFunction(
   }
 
   override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+
+  override def reset(): AggregateFunction1 =
+    new ApproxCountDistinctMergeFunction(expr, base, relativeSD)
 }
 
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
@@ -471,6 +511,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
       sum.update(addFunction(evaluatedExpr), input)
     }
   }
+
+  override def reset(): AggregateFunction1 = {
+    count = 0
+    sum.update(zero.eval(null))
+    this
+  }
 }
 
 case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
@@ -528,6 +574,11 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
     sum.update(addFunction, input)
   }
 
+  override def reset(): AggregateFunction1 = {
+    sum.update(null)
+    this
+  }
+
   override def eval(input: InternalRow): Any = {
     expr.dataType match {
       case DecimalType.Fixed(_, _) =>
@@ -576,6 +627,11 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    seen.clear()
+    this
+  }
+
   override def eval(input: InternalRow): Any = {
     if (seen.size == 0) {
       null
@@ -617,6 +673,11 @@ case class CombineSetsAndSumFunction(
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    seen.clear()
+    this
+  }
+
   override def eval(input: InternalRow): Any = {
     val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
     if (casted.size == 0) {
@@ -656,6 +717,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    result = null
+    this
+  }
+
   override def eval(input: InternalRow): Any = result
 }
 
@@ -687,6 +753,11 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
     }
   }
 
+  override def reset(): AggregateFunction1 = {
+    result = null
+    this
+  }
+
   override def eval(input: InternalRow): Any = {
     result
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 8c0c5d5b1e31..bf1bd84cdd91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -150,5 +150,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean
     value = expression.eval(input)
   }
 
+  def update(value: Any): Unit = {
+    this.value = value
+  }
+
   override def eval(input: InternalRow): Any = value
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 9de75f4c4d08..ce412f909f2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -114,6 +114,20 @@
       }
     }, _.toString, doc, isPublic)
 
+  def floatConf(
+      key: String,
+      defaultValue: Option[Float] = None,
+      doc: String = "",
+      isPublic: Boolean = true): SQLConfEntry[Float] =
+    SQLConfEntry(key, defaultValue, { v =>
+      try {
+        v.toFloat
+      } catch {
+        case _: NumberFormatException =>
+          throw new IllegalArgumentException(s"$key should be float, but was $v")
+      }
+    }, _.toString, doc, isPublic)
+
   def doubleConf(
       key: String,
       defaultValue: Option[Double] = None,
@@ -420,6 +434,17 @@
   val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
     defaultValue = Some(true), doc = "")
 
+  val PARTIAL_AGGREGATION_CHECK_INTERVAL = intConf(
+    "spark.sql.partial.aggregation.checkInterval",
+    defaultValue = Some(1000000),
+    doc = "Number of input rows between checks of the hash aggregation reduction ratio")
+
+  val PARTIAL_AGGREGATION_MIN_REDUCTION = floatConf(
+    "spark.sql.partial.aggregation.minReduction",
+    defaultValue = Some(0.5f),
+    doc = "Partial aggregation will be turned off when the ratio of distinct groups to " +
+      "input rows exceeds this threshold")
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -527,6 +552,12 @@
 
   private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS)
 
+  private[spark] def partialAggregationCheckInterval: Int =
+    getConf(PARTIAL_AGGREGATION_CHECK_INTERVAL)
+
+  private[spark] def partialAggregationMinReduction: Float =
+    getConf(PARTIAL_AGGREGATION_MIN_REDUCTION)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index f3b6a3a5f4a3..aae45b1c27bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.HashMap
+import java.util
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
@@ -43,6 +43,8 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
+    checkInterval: Int,
+    minReduction: Float,
     child: SparkPlan)
   extends UnaryNode {
 
@@ -155,16 +157,22 @@
       }
     } else {
       child.execute().mapPartitions { iter =>
-        val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
+        val hashTable = new util.HashMap[InternalRow, Array[AggregateFunction1]]
         val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
 
+        var numInput: Int = 0
+        var numOutput: Int = 0
+
+        var disabled: Boolean = false
         var currentRow: InternalRow = null
-        while (iter.hasNext) {
+        while (!disabled && iter.hasNext) {
+          numInput += 1
           currentRow = iter.next()
           numInputRows += 1
           val currentGroup = groupingProjection(currentRow)
           var currentBuffer = hashTable.get(currentGroup)
           if (currentBuffer == null) {
+            numOutput += 1
             currentBuffer = newAggregateBuffer()
             hashTable.put(currentGroup.copy(), currentBuffer)
           }
@@ -174,15 +182,22 @@
             currentBuffer(i).update(currentRow)
             i += 1
           }
+          if (partial && minReduction > 0 &&
+            (numInput % checkInterval) == 0 && (numOutput.toFloat / numInput > minReduction)) {
+            log.info("Hash aggregation is disabled due to insufficient reduction ratio " +
+              (numOutput.toFloat / numInput) + ", expected to be less than " + minReduction)
+            disabled = true
+          }
         }
 
-        new Iterator[InternalRow] {
-          private[this] val hashTableIter = hashTable.entrySet().iterator()
-          private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
-          private[this] val resultProjection =
+        val joinedRow = new JoinedRow
+        val aggregateResults = new GenericMutableRow(computedAggregates.length)
+        val resultProjection =
           new InterpretedMutableProjection(
             resultExpressions, computedSchema ++ namedGroups.map(_._2))
-          private[this] val joinedRow = new JoinedRow
+
+        val result = new Iterator[InternalRow] {
+          private[this] val hashTableIter = hashTable.entrySet().iterator()
 
           override final def hasNext: Boolean = hashTableIter.hasNext
 
@@ -202,6 +217,27 @@
             resultProjection(joinedRow(aggregateResults, currentGroup))
           }
         }
+        if (!iter.hasNext) {
+          result
+        } else {
+          val currentBuffer = newAggregateBuffer()
+          result ++ new Iterator[InternalRow] {
+            override final def hasNext: Boolean = iter.hasNext
+
+            override final def next(): InternalRow = {
+              currentRow = iter.next()
+              val currentGroup = groupingProjection(currentRow)
+
+              var i = 0
+              while (i < currentBuffer.length) {
+                currentBuffer(i).reset().update(currentRow)
+                aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+                i += 1
+              }
+              resultProjection(joinedRow(aggregateResults, currentGroup))
+            }
+          }
+        }
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4df53687a073..5477b0cfc81b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.{SQLContext, Strategy, execution}
 
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -165,10 +164,13 @@
             partial = false,
             namedGroupingAttributes,
             rewrittenAggregateExpressions,
+            -1, -1,
             execution.Aggregate(
               partial = true,
               groupingExpressions,
               partialComputation,
+              sqlContext.conf.partialAggregationCheckInterval,
+              sqlContext.conf.partialAggregationMinReduction,
               planLater(child))) :: Nil
 
         case _ => Nil
@@ -360,7 +362,7 @@
           Nil
         } else {
           Utils.checkInvalidAggregateFunction2(a)
-          execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+          execution.Aggregate(partial = false, group, agg, -1, -1, planLater(child)) :: Nil
         }
       }
       case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad02373e5ba..14c1197df67b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -628,5 +628,10 @@ private[hive] case class HiveUDAFFunction(
     val inputs = inputProjection(input)
     function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
   }
+
+  override def reset(): AggregateFunction1 = {
+    function.reset(buffer)
+    this
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 83f9f3eaa3a5..80d8222f3e14 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1087,6 +1087,37 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     assert(getConf(testKey, "0") == "")
   }
 
+  test("SPARK-8153: Add configuration for disabling partial aggregation in runtime") {
+    val testData = sparkContext.parallelize(
+      TestData(1, "100") ::
+      TestData(1, "200") ::
+      TestData(1, "300") ::
+      TestData(2, "100") ::
+      TestData(2, "200") ::
+      TestData(3, "100") ::
+      TestData(3, "200") ::
+      TestData(3, "300") :: Nil)
+    testData.toDF().registerTempTable("test8153")
+
+    val query: String = "select A, count(B) as c from test8153 group by A having c > 2 order by A"
+
+    sql("set spark.sql.codegen=false")
+    sql("set spark.sql.unsafe.enabled=false")
+    sql("set spark.sql.partial.aggregation.checkInterval=1000000")
+
+    val gold = sql(query).collect()
+
+    // force disabling hash aggregation (with checkInterval=1 the observed ratio is always 1.0)
+    sql("set spark.sql.partial.aggregation.checkInterval=1")
+    assertResult(gold){sql(query).collect()}
+
+    sql("set spark.sql.codegen=true")
+    assertResult(gold){sql(query).collect()}
+
+    sql("set spark.sql.unsafe.enabled=true")
+    assertResult(gold){sql(query).collect()}
+  }
+
   test("SET commands semantics for a HiveContext") {
     // Adapted from its SQL counterpart.
     val testKey = "spark.sql.key.usedfortestonly"
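Illustrative usage sketch (not part of the patch above): one way the two new settings might be exercised from a SQLContext once this change is applied. The table name "events", the query, and the chosen values are assumptions for the example, not anything taken from the patch.

// Sketch only, assuming a Spark 1.5-era SQLContext named `sqlContext` and a temp table "events".
import org.apache.spark.sql.{Row, SQLContext}

def groupWithRuntimePartialAggCheck(sqlContext: SQLContext): Array[Row] = {
  // Re-check the hash aggregation reduction ratio every 100,000 input rows.
  sqlContext.setConf("spark.sql.partial.aggregation.checkInterval", "100000")
  // Stop map-side hash aggregation for the rest of the partition once more than
  // half of the observed input rows have produced distinct groups.
  sqlContext.setConf("spark.sql.partial.aggregation.minReduction", "0.5")

  // Partial aggregation behaves as before while the hash table keeps merging rows;
  // if the ratio check trips, remaining rows are aggregated per-row and passed on
  // to the final (reduce-side) aggregate.
  sqlContext.sql("SELECT key, count(*) AS cnt FROM events GROUP BY key").collect()
}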