@@ -32,6 +32,10 @@ class BitSet(numBits: Int) extends Serializable {
*/
def capacity: Int = numWords * 64

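/** Sets all bits to 0, emptying this bitset. */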
def clear(): Unit = {
java.util.Arrays.fill(words, 0L)
}

/**
* Set all the bits up to a given index
*/
@@ -104,6 +104,13 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
/** Return true if this set contains the specified element. */
def contains(k: T): Boolean = getPos(k) != INVALID_POS

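/** Removes all elements and restores the initial capacity, so the instance can be reused. */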
def clear(): Unit = {
_size = 0
_bitset.clear()
_capacity = nextPowerOf2(initialCapacity)
_data = new Array[T](_capacity)
}

/**
* Add an element to the set. If the set is over capacity after the insertion, grow the set
* and rehash all elements.
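The new clear() methods let a set be emptied and reused across groups instead of reallocated. A minimal sketch of that reuse pattern (ClearDemo and its package are hypothetical; since OpenHashSet is private[spark], such code would have to live under the org.apache.spark package):

package org.apache.spark.demo // hypothetical, to satisfy private[spark]

import org.apache.spark.util.collection.OpenHashSet

object ClearDemo {
  def main(args: Array[String]): Unit = {
    val seen = new OpenHashSet[Long](64)
    (1L to 100L).foreach(seen.add)
    assert(seen.size == 100)
    seen.clear()            // zero the bitset, reset size and capacity
    assert(seen.size == 0)
    seen.add(42L)           // the same instance is immediately reusable
    assert(seen.contains(42L))
  }
}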
@@ -41,6 +41,8 @@ trait AggregateExpression1 extends AggregateExpression {
* of input rows.
*/
def newInstance(): AggregateFunction1

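/** Returns a function ready to aggregate a new group of rows; defaults to a fresh instance. */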
def reset(): AggregateFunction1 = newInstance()
}

/**
@@ -81,6 +83,8 @@ abstract class AggregateFunction1 extends LeafExpression with Serializable {

def update(input: InternalRow): Unit

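/**
 * Resets any accumulated state so this function can aggregate a new group.
 * Callers must use the returned instance, which may be `this` (reset in place)
 * or a fresh copy.
 */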
def reset(): AggregateFunction1

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
throw new UnsupportedOperationException(
"AggregateFunction1 should not be used for generated aggregates")
@@ -117,6 +121,11 @@ case class MinFunction(expr: Expression, base: AggregateExpression1) extends Agg
}
}

override def reset(): AggregateFunction1 = {
currentMin.update(null)
this
}

override def eval(input: InternalRow): Any = currentMin.value
}

@@ -150,6 +159,11 @@ case class MaxFunction(expr: Expression, base: AggregateExpression1) extends Agg
}
}

override def reset(): AggregateFunction1 = {
currentMax.update(null)
this
}

override def eval(input: InternalRow): Any = currentMax.value
}

@@ -178,6 +192,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression1) extends A
}
}

override def reset(): AggregateFunction1 = {
count = 0
this
}

override def eval(input: InternalRow): Any = count
}

@@ -218,6 +237,11 @@ case class CountDistinctFunction(
}
}

override def reset(): AggregateFunction1 = {
seen.clear()
this
}

override def eval(input: InternalRow): Any = seen.size.toLong
}

@@ -251,6 +275,11 @@ case class CollectHashSetFunction(
}
}

override def reset(): AggregateFunction1 = {
seen.clear()
this
}

override def eval(input: InternalRow): Any = {
seen
}
@@ -285,6 +314,11 @@ case class CombineSetsAndCountFunction(
}
}

override def reset(): AggregateFunction1 = {
seen.clear()
this
}

override def eval(input: InternalRow): Any = seen.size.toLong
}

@@ -333,6 +367,9 @@ case class ApproxCountDistinctPartitionFunction(
}

override def eval(input: InternalRow): Any = hyperLogLog

override def reset(): AggregateFunction1 =
new ApproxCountDistinctPartitionFunction(expr, base, relativeSD)
}

case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
@@ -361,6 +398,9 @@ case class ApproxCountDistinctMergeFunction(
}

override def eval(input: InternalRow): Any = hyperLogLog.cardinality()

override def reset(): AggregateFunction1 =
new ApproxCountDistinctMergeFunction(expr, base, relativeSD)
}

case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
@@ -471,6 +511,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
sum.update(addFunction(evaluatedExpr), input)
}
}

override def reset(): AggregateFunction1 = {
count = 0
sum.update(zero.eval(null))
this
}
}

case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
@@ -528,6 +574,11 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
sum.update(addFunction, input)
}

override def reset(): AggregateFunction1 = {
sum.update(null)
this
}

override def eval(input: InternalRow): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@@ -576,6 +627,11 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
}
}

override def reset(): AggregateFunction1 = {
seen.clear()
this
}

override def eval(input: InternalRow): Any = {
if (seen.size == 0) {
null
@@ -617,6 +673,11 @@ case class CombineSetsAndSumFunction(
}
}

override def reset(): AggregateFunction1 = {
seen.clear()
this
}

override def eval(input: InternalRow): Any = {
val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
if (casted.size == 0) {
@@ -656,6 +717,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
}
}

override def reset(): AggregateFunction1 = {
result = null
this
}

override def eval(input: InternalRow): Any = result
}

@@ -687,6 +753,11 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
}
}

override def reset(): AggregateFunction1 = {
result = null
this
}

override def eval(input: InternalRow): Any = {
result
}
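Each implementation above follows the same contract. Here is a self-contained toy sketch of it (plain Scala, not Spark code; Agg, CountAgg and ResetDemo are made up for illustration): clear the state, then hand back the instance the caller must use from then on.

trait Agg {
  def update(v: Int): Unit
  def eval: Long
  def reset(): Agg // may return `this` or a replacement instance
}

final class CountAgg extends Agg {
  private var count = 0L
  def update(v: Int): Unit = count += 1
  def eval: Long = count
  def reset(): Agg = { count = 0L; this } // in place, like CountFunction above
}

object ResetDemo extends App {
  var f: Agg = new CountAgg
  Seq(1, 2, 3).foreach(f.update)
  assert(f.eval == 3L)
  f = f.reset() // always keep using the returned reference
  f.update(7)
  assert(f.eval == 1L)
}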
@@ -150,5 +150,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean
value = expression.eval(input)
}

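/** Overwrites the current value directly, without evaluating an expression. */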
def update(value: Any): Unit = {
this.value = value
}

override def eval(input: InternalRow): Any = value
}
sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala (31 additions, 0 deletions)
@@ -114,6 +114,20 @@ private[spark] object SQLConf {
}
}, _.toString, doc, isPublic)

def floatConf(
key: String,
defaultValue: Option[Float] = None,
doc: String = "",
isPublic: Boolean = true): SQLConfEntry[Float] =
SQLConfEntry(key, defaultValue, { v =>
try {
v.toFloat
} catch {
case _: NumberFormatException =>
throw new IllegalArgumentException(s"$key should be float, but was $v")
}
}, _.toString, doc, isPublic)

def doubleConf(
key: String,
defaultValue: Option[Double] = None,
@@ -420,6 +434,17 @@ private[spark] object SQLConf {
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
defaultValue = Some(true), doc = "<TODO>")

val PARTIAL_AGGREGATION_CHECK_INTERVAL = intConf(
"spark.sql.partial.aggregation.checkInterval",
defaultValue = Some(1000000),
doc = "Number of input rows for checking aggregation status of hash")

val PARTIAL_AGGREGATION_MIN_REDUCTION = floatConf(
"spark.sql.partial.aggregation.minReduction",
defaultValue = Some(0.5f),
doc = "Partial aggregation will be turned off when reduction ratio by hash is not " +
"enough to this threshold")

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
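These entries read like the existing ones; a short sketch of tuning them at runtime (the sqlContext value and the chosen numbers are assumptions for illustration, not part of this change):

// Assumes an existing SQLContext instance named sqlContext.
sqlContext.setConf("spark.sql.partial.aggregation.checkInterval", "500000")
sqlContext.setConf("spark.sql.partial.aggregation.minReduction", "0.8")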
@@ -527,6 +552,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS)

private[spark] def partialAggregationCheckInterval: Int =
getConf(PARTIAL_AGGREGATION_CHECK_INTERVAL)

private[spark] def partialAggregationMinReduction: Float =
getConf(PARTIAL_AGGREGATION_MIN_REDUCTION)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

-import java.util.HashMap
+import java.util

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
Expand All @@ -43,6 +43,8 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
checkInterval: Int,
minReduction: Float,
child: SparkPlan)
extends UnaryNode {

@@ -155,16 +157,22 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
-val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
+val hashTable = new util.HashMap[InternalRow, Array[AggregateFunction1]]
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)

var numInput: Int = 0
var numOutput: Int = 0

var disabled: Boolean = false
var currentRow: InternalRow = null
-while (iter.hasNext) {
+while (!disabled && iter.hasNext) {
numInput += 1
currentRow = iter.next()
numInputRows += 1
val currentGroup = groupingProjection(currentRow)
var currentBuffer = hashTable.get(currentGroup)
if (currentBuffer == null) {
numOutput += 1
currentBuffer = newAggregateBuffer()
hashTable.put(currentGroup.copy(), currentBuffer)
}
Expand All @@ -174,15 +182,22 @@ case class Aggregate(
currentBuffer(i).update(currentRow)
i += 1
}
if (partial && minReduction > 0 &&
  (numInput % checkInterval) == 0 && (numOutput.toFloat / numInput > minReduction)) {
  log.info("Disabling partial hash aggregation: reduction ratio " +
    (numOutput.toFloat / numInput) + " is not below the threshold " + minReduction)
disabled = true
}
}

-new Iterator[InternalRow] {
-  private[this] val hashTableIter = hashTable.entrySet().iterator()
-  private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
-  private[this] val resultProjection =
-    new InterpretedMutableProjection(
-      resultExpressions, computedSchema ++ namedGroups.map(_._2))
-  private[this] val joinedRow = new JoinedRow
+val joinedRow = new JoinedRow
+val aggregateResults = new GenericMutableRow(computedAggregates.length)
+val resultProjection =
+  new InterpretedMutableProjection(
+    resultExpressions, computedSchema ++ namedGroups.map(_._2))
+
+val result = new Iterator[InternalRow] {
+  private[this] val hashTableIter = hashTable.entrySet().iterator()

override final def hasNext: Boolean = hashTableIter.hasNext

Expand All @@ -202,6 +217,27 @@ case class Aggregate(
resultProjection(joinedRow(aggregateResults, currentGroup))
}
}
if (!iter.hasNext) {
result
} else {
val currentBuffer = newAggregateBuffer()
result ++ new Iterator[InternalRow] {
override final def hasNext: Boolean = iter.hasNext

override final def next(): InternalRow = {
currentRow = iter.next()
val currentGroup = groupingProjection(currentRow)

var i = 0
while (i < currentBuffer.length) {
currentBuffer(i).reset().update(currentRow)
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
i += 1
}
resultProjection(joinedRow(aggregateResults, currentGroup))
}
}
}
}
}
}
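For intuition, a small worked example of the disable condition above (all numbers are assumed and match the defaults):

// With the default checkInterval = 1000000 and minReduction = 0.5f:
val checkInterval = 1000000
val minReduction = 0.5f
val numInput = 1000000      // rows consumed in this partition so far
val numOutput = 800000      // distinct groups created so far
val ratio = numOutput.toFloat / numInput                 // 0.8: poor reduction
val disable = (numInput % checkInterval) == 0 && ratio > minReduction
// disable == true: the remaining rows bypass the hash map and are emitted
// one group per row via reset()/update()/eval().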
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLContext, Strategy, execution}

private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -165,10 +164,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = false,
namedGroupingAttributes,
rewrittenAggregateExpressions,
-1, -1, // sentinels: the reduction-ratio check is disabled for the final aggregate
execution.Aggregate(
partial = true,
groupingExpressions,
partialComputation,
sqlContext.conf.partialAggregationCheckInterval,
sqlContext.conf.partialAggregationMinReduction,
planLater(child))) :: Nil

case _ => Nil
Expand Down Expand Up @@ -360,7 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Nil
} else {
Utils.checkInvalidAggregateFunction2(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
execution.Aggregate(partial = false, group, agg, -1, -1, planLater(child)) :: Nil
}
}
case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
Expand Down