From 118edcf16334251cf07026d96be9d0ab43894843 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 22 Jul 2015 18:06:30 -0400 Subject: [PATCH 01/19] Added WindowFunctions. --- .../expressions/windowExpressions.scala | 188 +++++++++++++++++- .../apache/spark/sql/execution/Window.scala | 2 +- 2 files changed, 187 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa4..8fe9908e65e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.catalyst.expressions.aggregate.AlgebraicAggregate +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -305,7 +308,7 @@ case class UnresolvedWindowExpression( } case class WindowExpression( - windowFunction: WindowFunction, + windowFunction: Expression, windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil @@ -328,3 +331,184 @@ object FrameBoundaryExtractor { case _ => None } } + +/** + * A window function is a function that can only be evaluated in the context of a window operator. + */ +trait WindowFunction2 extends Expression { + /** + * Define the frame in which the window operator must be executed. 
+ */ + def frame: WindowFrame = UnspecifiedFrame +} + +abstract class OffsetWindowFunction(child: Expression, offset: Int, default: Expression) + extends Expression with WindowFunction2 with CodegenFallback { + self: Product => + + override lazy val resolved = child.resolved && default.resolved && child.dataType == default.dataType + + override def children: Seq[Expression] = child :: default :: Nil + + override def dataType: DataType = child.dataType + + override def foldable: Boolean = child.foldable && default.foldable + + override def nullable: Boolean = child.nullable && default.nullable + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) result + else default.eval(input) + } + + override def toString: String = s"${simpleString}($child, $offset, $default)" +} + +case class Lead(child: Expression, offset: Int, default: Expression) + extends OffsetWindowFunction(child, offset, default) { + def this(child: Expression, offset: Int) = + this(child, offset, Literal.create(null, child.dataType)) + + def this(child: Expression) = + this(child, 1, Literal.create(null, child.dataType)) + + override val frame = SpecifiedWindowFrame(RowFrame, + ValueFollowing(offset), + ValueFollowing(offset)) +} + +case class Lag(child: Expression, offset: Int, default: Expression) + extends OffsetWindowFunction(child, offset, default) { + def this(child: Expression, offset: Int) = + this(child, offset, Literal.create(null, child.dataType)) + + def this(child: Expression) = + this(child, 1, Literal.create(null, child.dataType)) + + override val frame = SpecifiedWindowFrame(RowFrame, + ValuePreceding(offset), + ValuePreceding(offset)) +} + +abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFunction2 { + self:Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def foldable: Boolean = false + override def nullable: Boolean = false + def withContext(order: Seq[SortOrder], n: MutableLiteral): AggregateWindowFunction = this + override val mergeExpressions = Nil // TODO how to deal with this? +} + +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def deterministic: Boolean = false + override def inputTypes: Seq[AbstractDataType] = Nil + protected val rowNumber = AttributeReference("rowNumber", IntegerType)() + override val bufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = Literal(0) :: Nil + override val updateExpressions: Seq[Expression] = rowNumber + 1 :: Nil +} + +case object RowNumber extends RowNumberLike { + override val evaluateExpression = Cast(rowNumber, IntegerType) +} + +// TODO check if this works in combination with CodeGeneration? +case class CumeDist(n: MutableLiteral) extends RowNumberLike { + def this() = this(MutableLiteral(0, IntegerType)) + override def dataType: DataType = DoubleType + override def deterministic: Boolean = true + override def withContext(order: Seq[SortOrder], n: MutableLiteral): CumeDist = CumeDist(n) + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Cast(rowNumber / n, DoubleType) +} + +// TODO check if this works in combination with CodeGeneration? 
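// A sketch of the bucket math implemented by NTile below (hypothetical standalone
// code, not part of this patch): the first n % buckets buckets receive one padding
// row each, the remaining buckets hold n / buckets rows.
private object NTileSketch {
  // bucketSizes(10, 4) == Seq(3, 3, 2, 2)
  def bucketSizes(n: Int, buckets: Int): Seq[Int] =
    Seq.tabulate(buckets)(i => n / buckets + (if (i < n % buckets) 1 else 0))
}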
+// TODO check logic +case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike { + require(buckets > 0, "Number of buckets must be > 0") + def this(buckets: Int) = this(MutableLiteral(0, IntegerType), buckets) + override def withContext(order: Seq[SortOrder], n: MutableLiteral): NTile = NTile(n, buckets) + private val bucket = AttributeReference("bucket", IntegerType)() + private val bucketThreshold = AttributeReference("bucketThreshold", IntegerType)() + private val bucketSize = AttributeReference("bucketSize", IntegerType)() + private val bucketsWithPadding = AttributeReference("bucketsWithPadding", IntegerType)() + + override val bufferAttributes = Seq( + rowNumber, + bucket, + bucketThreshold, + bucketSize, + bucketsWithPadding + ) + + override val initialValues = Seq( + Literal(0), + Literal(0), + Literal(0), + Cast(n / buckets, IntegerType), + Cast(n % buckets, IntegerType) + ) + + override val updateExpressions = Seq( + rowNumber + 1, + bucket +If(rowNumber >= bucketThreshold, 1, 0), + bucketThreshold + + If(rowNumber >= bucketThreshold, bucketSize + If(bucket <= bucketsWithPadding, 1, 0), 0), + bucketSize, + bucketsWithPadding + ) + + override val evaluateExpression = bucket +} + +abstract class RankLike(order: Seq[SortOrder]) extends AggregateWindowFunction { + override def children: Seq[Expression] = order + override def deterministic: Boolean = true + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + protected val orderExprs = order.map(_.expr) + + protected val orderAttrs = orderExprs.zipWithIndex.map{ case (expr, i) => + AttributeReference(i.toString, expr.dataType)() + } + + protected val orderEquals = orderExprs.zip(orderAttrs).map(EqualNullSafe.tupled).reduce(And) + protected val orderInit = orderExprs.map(e => Literal.create(null, e.dataType)) + protected val rank = AttributeReference("rank", IntegerType)() + protected val rowNumber = AttributeReference("rowNumber", IntegerType)() + protected val updateRank = If(And(orderEquals, rank !== 0), rank, doUpdateRank) + + // Implementation for RANK() + protected val doUpdateRank: Expression = rowNumber + 1L + override val bufferAttributes = rank +: rowNumber +: orderAttrs + override val initialValues = Literal(0) +: Literal(0) +: orderInit + override val updateExpressions = doUpdateRank +: (rowNumber + 1) +: orderExprs + override val evaluateExpression: Expression = Cast(rank, LongType) + +} + +case class Rank(order: Seq[SortOrder]) extends RankLike(order) { + def this() = this(Nil) + override def withContext(order: Seq[SortOrder], n: MutableLiteral): Rank = Rank(order) +} + +case class DenseRank(order: Seq[SortOrder]) extends RankLike(order) { + def this() = this(Nil) + override def withContext(o: Seq[SortOrder], n: MutableLiteral): DenseRank = DenseRank(o) + override val bufferAttributes = rank +: orderAttrs + override val initialValues = Literal(0) +: orderInit + override val updateExpressions = doUpdateRank +: orderExprs + override val doUpdateRank = rank + 1 +} + +// TODO check if this works in combination with CodeGeneration? 
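// A sketch of the PERCENT_RANK definition implemented below (hypothetical standalone
// code, not part of this patch): (rank - 1) / (n - 1) over a partition of n rows,
// defaulting to 0.0 for single-row partitions.
private object PercentRankSketch {
  // percentRank(1, 5) == 0.0; percentRank(5, 5) == 1.0; percentRank(1, 1) == 0.0
  def percentRank(rank: Int, n: Int): Double =
    if (n > 1) (rank - 1).toDouble / (n - 1) else 0.0d
}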
+case class PercentRank(order: Seq[SortOrder], n: MutableLiteral) extends RankLike(order) { + def this() = this(Nil, MutableLiteral(0, IntegerType)) + override def withContext(o: Seq[SortOrder], n: MutableLiteral): PercentRank = PercentRank(o, n) + override def dataType: DataType = DoubleType + override val evaluateExpression = + If(n > 1, Cast((rank - 1) / (n - 1), DoubleType), Literal.create(0.0d, DoubleType)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1..9e066712fcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -240,7 +240,7 @@ case class Window( // Bind the expressions. val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) + BindReferences.bindReference(e.windowFunction.asInstanceOf[WindowFunction], child.output) }.toArray // Create the frame processor factory. From 84401e79d6567004771df6340e40c8e9807d78ec Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 26 Jul 2015 00:33:47 -0400 Subject: [PATCH 02/19] Move to Native Spark UDAFs for window processing. --- .../sql/catalyst/analysis/Analyzer.scala | 115 +++- .../catalyst/analysis/FunctionRegistry.scala | 12 +- .../catalyst/expressions/aggregate/sets.scala | 108 ++++ .../expressions/aggregate/utils.scala | 114 ++-- .../expressions/windowExpressions.scala | 38 +- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../apache/spark/sql/execution/Window.scala | 572 +++++++++++------- .../spark/sql/expressions/WindowSpec.scala | 32 +- .../org/apache/spark/sql/functions.scala | 16 +- .../org/apache/spark/sql/hive/HiveQl.scala | 23 +- 10 files changed, 695 insertions(+), 346 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa4..98ab61c57648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -76,6 +76,7 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: @@ -557,11 +558,18 @@ class Analyzer( } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. + val blacklist = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !blacklist.contains(ae) => ae + }.isDefined) + } } @@ -763,26 +771,38 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). 
case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + val newChildren = wf.children.map(extractExpr) + wf.withNewChildren(newChildren) + + case wf : WindowFunction2 => + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression(agg: AggregateExpression, spec: WindowSpecDefinition) => + val newAggChildren = agg.children.map(extractExpr) + val newAgg = agg.withNewChildren(newAggChildren) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -957,6 +977,85 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Removes all still-need-evaluate ordering expressions from sort and use an inner project to + * materialize them, finally use a outer project to project them away to keep the result same. + * Then we can make sure we only sort by [[AttributeReference]]s. + * + * As an example, + * {{{ + * Sort('a, 'b + 1, + * Relation('a, 'b)) + * }}} + * will be turned into: + * {{{ + * Project('a, 'b, + * Sort('a, '_sortCondition, + * Project('a, 'b, ('b + 1).as("_sortCondition"), + * Relation('a, 'b)))) + * }}} + */ + object RemoveEvaluationFromSort extends Rule[LogicalPlan] { + private def hasAlias(expr: Expression) = { + expr.find { + case a: Alias => true + case _ => false + }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // The ordering expressions have no effect to the output schema of `Sort`, + // so `Alias`s in ordering expressions are unnecessary and we should remove them. + case s@Sort(ordering, _, _) if ordering.exists(hasAlias) => + val newOrdering = ordering.map(_.transformUp { + case Alias(child, _) => child + }.asInstanceOf[SortOrder]) + s.copy(order = newOrdering) + + case s@Sort(ordering, global, child) + if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => + + val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) + + val namedExpr = needEval.map(_.child match { + case n: NamedExpression => n + case e => Alias(e, "_sortCondition")() + }) + + val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => + order.copy(child = ne.toAttribute) + } + + // Add still-need-evaluate ordering expressions into inner project and then project + // them away after the sort. + Project(child.output, + Sort(newOrdering, global, + Project(child.output ++ namedExpr, child))) + } + } + + /* + * Check and add proper window frames for all window functions. 
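 * A function's own required frame (e.g. the single-row offset frame of LEAD/LAG)
 * fills in an unspecified frame; an explicitly specified frame that conflicts with
 * the required one fails analysis.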
+ */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical.transformExpressionsDown { + case WindowExpression(wf: WindowFunction2, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"The frame of the window '$f' does not match the required frame " + + s"'${wf.frame}'") + case WindowExpression(wf: WindowFunction2, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(!o.isEmpty, false) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d78815..5f8c1704749b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -250,7 +250,17 @@ object FunctionRegistry { expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), - expression[InputFileName]("input_file_name") + expression[InputFileName]("input_file_name"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala new file mode 100644 index 000000000000..5f8e4618f3b5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext, CodegenFallback} +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.collection.OpenHashSet + + +/** Reduce a set using an algebraic expression. 
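 * Used on the evaluation path of DISTINCT aggregates: the OpenHashSet collected for
 * the distinct children is folded through the aggregate's initial and update
 * expressions, after which the evaluate expression is applied once.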
*/ +case class ReduceSetAlgebraic(left: Expression, right: AlgebraicAggregate) + extends BinaryExpression with CodegenFallback { + + override def dataType: DataType = right.dataType + + private[this] val single = right.children.size == 1 + private[this] val singleValueOrdinal = right.bufferSchema.length + + // This might be taking reuse too far... + @transient private[this] lazy val buffer = { + val singleSize = if (single) 1 else 0 + new GenericMutableRow(singleValueOrdinal + singleSize) + } + + @transient private[this] lazy val initial = + InterpretedMutableProjection(right.initialValues).target(buffer) + + @transient private[this] lazy val update = { + val schema = right.bufferAttributes ++ right.children.map { child => + AttributeReference("child", child.dataType, child.nullable)() + } + new InterpretedMutableProjection(right.updateExpressions, schema).target(buffer) + } + + @transient private[this] lazy val evaluate = + BindReferences.bindReference(right.evaluateExpression, right.bufferSchema.toAttributes) + + @transient private[this] lazy val joinRow = new JoinedRow4 + + override def eval(input: InternalRow): Any = { + val result = left.eval(input).asInstanceOf[OpenHashSet[Any]] + if (result != null) { + initial(EmptyRow) + val iterator = result.iterator + // Prevent branch during iteration. + if (single) { + while (iterator.hasNext) { + buffer.update(singleValueOrdinal, iterator.next) + update(buffer) + } + } else { + while (iterator.hasNext) { + joinRow(buffer, iterator.next.asInstanceOf[InternalRow]) + update(joinRow) + } + } + evaluate.eval(buffer) + } else null + } +} +/** Reduce a set using an AggregateFunction2. */ +case class ReduceSetAggregate(left: Expression, right: AggregateFunction2) + extends BinaryExpression with CodegenFallback { + + right.bufferOffset = 0 + + override def dataType: DataType = right.dataType + + private[this] val single = right.children.size == 1 + @transient private[this] lazy val buffer = new GenericMutableRow(right.bufferSchema.size) + @transient private[this] lazy val singleValueInput = new GenericMutableRow(1) + + override def eval(input: InternalRow): Any = { + val result = left.eval(input).asInstanceOf[OpenHashSet[Any]] + if (result != null) { + right.initialize(buffer) + val iterator = result.iterator + if (single) { + while (iterator.hasNext) { + singleValueInput.update(0, iterator.next()) + right.update(buffer, singleValueInput) + } + } else { + while (iterator.hasNext) { + right.update(buffer, iterator.next().asInstanceOf[InternalRow]) + } + } + right.eval(buffer) + } else null + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a9549..37d96eccdd31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -39,64 +39,66 @@ object Utils { !hasComplexTypes } + val convertAggregateExpressions: PartialFunction[Expression, Expression] = { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. 
+ case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } + val converted = p.transformExpressionsDown(convertAggregateExpressions) + // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. 
val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 8fe9908e65e7..4a80100b9cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -298,7 +298,7 @@ case class UnresolvedWindowFunction( } case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, + child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") @@ -342,11 +342,16 @@ trait WindowFunction2 extends Expression { def frame: WindowFrame = UnspecifiedFrame } +trait SizeBasedWindowFunction extends WindowFunction2 { + def withSize(n: MutableLiteral): SizeBasedWindowFunction +} + abstract class OffsetWindowFunction(child: Expression, offset: Int, default: Expression) extends Expression with WindowFunction2 with CodegenFallback { self: Product => - override lazy val resolved = child.resolved && default.resolved && child.dataType == default.dataType + override lazy val resolved = + child.resolved && default.resolved && child.dataType == default.dataType override def children: Seq[Expression] = child :: default :: Nil @@ -362,7 +367,7 @@ abstract class OffsetWindowFunction(child: Expression, offset: Int, default: Exp else default.eval(input) } - override def toString: String = s"${simpleString}($child, $offset, $default)" + override def toString: String = s"$simpleString($child, $offset, $default)" } case class Lead(child: Expression, offset: Int, default: Expression) @@ -392,12 +397,11 @@ case class Lag(child: Expression, offset: Int, default: Expression) } abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFunction2 { - self:Product => + self: Product => override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType override def foldable: Boolean = false override def nullable: Boolean = false - def withContext(order: Seq[SortOrder], n: MutableLiteral): AggregateWindowFunction = this override val mergeExpressions = Nil // TODO how to deal with this? } @@ -411,26 +415,27 @@ abstract class RowNumberLike extends AggregateWindowFunction { override val updateExpressions: Seq[Expression] = rowNumber + 1 :: Nil } -case object RowNumber extends RowNumberLike { +case class RowNumber() extends RowNumberLike { override val evaluateExpression = Cast(rowNumber, IntegerType) } // TODO check if this works in combination with CodeGeneration? -case class CumeDist(n: MutableLiteral) extends RowNumberLike { +case class CumeDist(n: MutableLiteral) extends RowNumberLike with SizeBasedWindowFunction { def this() = this(MutableLiteral(0, IntegerType)) override def dataType: DataType = DoubleType override def deterministic: Boolean = true - override def withContext(order: Seq[SortOrder], n: MutableLiteral): CumeDist = CumeDist(n) + override def withSize(n: MutableLiteral): CumeDist = CumeDist(n) override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) override val evaluateExpression = Cast(rowNumber / n, DoubleType) } // TODO check if this works in combination with CodeGeneration? 
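// Note on the SizeBasedWindowFunction contract used below: the Window operator
// rewrites these functions with withSize(partitionSize) when it builds a frame
// processor and sets partitionSize.value = rows.size whenever it loads a new
// partition, so n always holds the current partition's row count.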
// TODO check logic -case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike { +case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike + with SizeBasedWindowFunction { require(buckets > 0, "Number of buckets must be > 0") def this(buckets: Int) = this(MutableLiteral(0, IntegerType), buckets) - override def withContext(order: Seq[SortOrder], n: MutableLiteral): NTile = NTile(n, buckets) + override def withSize(n: MutableLiteral): NTile = NTile(n, buckets) private val bucket = AttributeReference("bucket", IntegerType)() private val bucketThreshold = AttributeReference("bucketThreshold", IntegerType)() private val bucketSize = AttributeReference("bucketSize", IntegerType)() @@ -454,7 +459,7 @@ case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike { override val updateExpressions = Seq( rowNumber + 1, - bucket +If(rowNumber >= bucketThreshold, 1, 0), + bucket + If(rowNumber >= bucketThreshold, 1, 0), bucketThreshold + If(rowNumber >= bucketThreshold, bucketSize + If(bucket <= bucketsWithPadding, 1, 0), 0), bucketSize, @@ -488,16 +493,17 @@ abstract class RankLike(order: Seq[SortOrder]) extends AggregateWindowFunction { override val updateExpressions = doUpdateRank +: (rowNumber + 1) +: orderExprs override val evaluateExpression: Expression = Cast(rank, LongType) + def withOrder(order: Seq[SortOrder]): RankLike } case class Rank(order: Seq[SortOrder]) extends RankLike(order) { def this() = this(Nil) - override def withContext(order: Seq[SortOrder], n: MutableLiteral): Rank = Rank(order) + override def withOrder(order: Seq[SortOrder]): Rank = Rank(order) } case class DenseRank(order: Seq[SortOrder]) extends RankLike(order) { def this() = this(Nil) - override def withContext(o: Seq[SortOrder], n: MutableLiteral): DenseRank = DenseRank(o) + override def withOrder(order: Seq[SortOrder]): DenseRank = DenseRank(order) override val bufferAttributes = rank +: orderAttrs override val initialValues = Literal(0) +: orderInit override val updateExpressions = doUpdateRank +: orderExprs @@ -505,9 +511,11 @@ case class DenseRank(order: Seq[SortOrder]) extends RankLike(order) { } // TODO check if this works in combination with CodeGeneration? 
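// A worked example for the RankLike variants above (hypothetical standalone code,
// not part of this patch), for rows ordered on the values (10, 20, 20, 30):
// ROW_NUMBER yields (1, 2, 3, 4), RANK (1, 2, 2, 4) and DENSE_RANK (1, 2, 2, 3).
private object RankSketch {
  // Assumes xs is already sorted by the window's ORDER BY clause.
  def rankOf(xs: Seq[Int]): Seq[Int] = xs.map(x => xs.indexOf(x) + 1)
  def denseRankOf(xs: Seq[Int]): Seq[Int] = xs.map(x => xs.distinct.indexOf(x) + 1)
  // rankOf(Seq(10, 20, 20, 30)) == Seq(1, 2, 2, 4)
  // denseRankOf(Seq(10, 20, 20, 30)) == Seq(1, 2, 2, 3)
}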
-case class PercentRank(order: Seq[SortOrder], n: MutableLiteral) extends RankLike(order) { +case class PercentRank(order: Seq[SortOrder], n: MutableLiteral) extends RankLike(order) + with SizeBasedWindowFunction { def this() = this(Nil, MutableLiteral(0, IntegerType)) - override def withContext(o: Seq[SortOrder], n: MutableLiteral): PercentRank = PercentRank(o, n) + override def withOrder(order: Seq[SortOrder]): PercentRank = PercentRank(order, n) + override def withSize(n: MutableLiteral): PercentRank = PercentRank(order, n) override def dataType: DataType = DoubleType override val evaluateExpression = If(n > 1, Cast((rank - 1) / (n - 1), DoubleType), Literal.create(0.0d, DoubleType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4b9b5acea4d..8da3c35c3298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -344,9 +344,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } + case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val convertedWindowExpressions = windowExprs.map { e => + val converted = e.transformDown(Utils.convertAggregateExpressions) + converted.asInstanceOf[NamedExpression] + } execution.Window( - projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + projectList, + convertedWindowExpressions, + partitionSpec, + orderSpec, + planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 9e066712fcce..834cc8810004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -22,8 +22,9 @@ import java.util import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{StructType, NullType, IntegerType} import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.CompactBuffer import scala.collection.mutable @@ -33,7 +34,7 @@ import scala.collection.mutable * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) * partition. The aggregates are calculated for each row in the group. Special processing * instructions, frames, are used to calculate these aggregates. Frames are processed in the order - * specified in the window specification (the ORDER BY ... clause). There are four different frame + * specified in the window specification (the ORDER BY ... clause). There are five different frame * types: * - Entire partition: The frame is the entire partition, i.e. * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. 
For this case, window function will take all @@ -47,6 +48,8 @@ import scala.collection.mutable * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only non-aggregate expressions can be evaluated in a offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -167,32 +170,54 @@ case class Window( */ private[this] def createFrameProcessor( frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame. - case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + functions: Array[Expression], + ordinal: Int, + result: MutableRow, + size: MutableLiteral): WindowFunctionFrame = { + // Construct the target row. + val target = if (ordinal == 0) result + else new OffsetMutableRow(ordinal, result) + + // Construct an aggregate processor if we have to. + def processor = { + val prepared = functions.map { + case f: SizeBasedWindowFunction => f.withSize(size) + case f => f + } + AggregateProcessor(prepared, child.output, newMutableProjection) + } + + // Create the frame processor. + frame match { + // Offset Frame + case SpecifiedWindowFrame(RowFrame, FrameBoundaryExtractor(l), FrameBoundaryExtractor(h)) + if l == h => + new OffsetWindowFunctionFrame(target, functions, child.output, newMutableProjection, l) + + // Growing Frame. + case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => + val uBoundOrdering = createBoundOrdering(frameType, high) + new UnboundedPrecedingWindowFunctionFrame(target, processor, uBoundOrdering) + + // Shrinking Frame. + case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => + val lBoundOrdering = createBoundOrdering(frameType, low) + new UnboundedFollowingWindowFunctionFrame(target, processor, lBoundOrdering) + + // Moving Frame. + case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(l), FrameBoundaryExtractor(h)) => + val lBoundOrdering = createBoundOrdering(frameType, l) + val uBoundOrdering = createBoundOrdering(frameType, h) + new SlidingWindowFunctionFrame(target, processor, lBoundOrdering, uBoundOrdering) + + // Entire Partition Frame. 
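// (ROWS/RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: the partition
// is aggregated once in prepare() and only the evaluation step runs per output row.)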
+ case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => + new UnboundedWindowFunctionFrame(target, processor) + + // Error + case fr => + sys.error(s"Unsupported Frame $fr for functions: $functions") + } } /** @@ -228,7 +253,8 @@ case class Window( // are processed in; this is the order in which their results will be written to window // function result buffer. val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) + val factories = Array.ofDim[(MutableRow, MutableLiteral) => + WindowFunctionFrame](framedWindowExprs.size) val unboundExpressions = mutable.Buffer.empty[Expression] framedWindowExprs.zipWithIndex.foreach { case ((frame, unboundFrameExpressions), index) => @@ -240,11 +266,18 @@ case class Window( // Bind the expressions. val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction.asInstanceOf[WindowFunction], child.output) + // Perhaps move code below to analyser. The dependency used in the pattern match might + // be to narrow (only RankLike and its subclasses). + val function = e.windowFunction match { + case r: RankLike => r.withOrder(windowSpec.orderSpec) + case f => f + } + BindReferences.bindReference(function, child.output) }.toArray // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) + factories(index) = (result: MutableRow, size: MutableLiteral) => + createFrameProcessor(frame, functions, ordinal, result, size) } // Start processing. @@ -273,7 +306,11 @@ case class Window( // Manage the current partition. var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) + val partitionSize = MutableLiteral(0, IntegerType, nullable = false) + val frames: Array[WindowFunctionFrame] = factories.map{ f => + f(windowFunctionResult, partitionSize) + } val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. @@ -284,6 +321,9 @@ case class Window( fetchNextRow() } + // Propagate partition size. + partitionSize.value = rows.size + // Setup the frames. var i = 0 while (i < numFrames) { @@ -302,7 +342,6 @@ case class Window( override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -313,7 +352,7 @@ case class Window( // Get the results for the window frames. var i = 0 while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write() i += 1 } @@ -360,48 +399,8 @@ private[execution] final case class RangeBoundOrdering( * A window function calculates the results of a number of window functions for a window frame. * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. 
But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { - - // Make sure functions are initialized. - functions.foreach(_.init()) - - /** Number of columns the window function frame is managing */ - val numColumns = functions.length - - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame - - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) - +private[execution] abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * @@ -410,90 +409,60 @@ private[execution] abstract class WindowFunctionFrame( def prepare(rows: CompactBuffer[InternalRow]): Unit /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. + * Write the current results to the target row. */ - def write(target: GenericMutableRow): Unit + def write(): Unit +} - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 - } - } +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared - } + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null - /** Evaluate a prepared buffer (iterator). */ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } - } - evaluate() - } + /** Index of the row we are currently using for output. 
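   * LEAD starts this index at a positive offset and LAG at a negative one, so it
   * may fall outside the partition; write() then nulls out the output columns.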
*/ + private[this] var inputIndex = 0 - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() - i += 1 - } - } + /** Check if the output has been explicitly cleared. */ + private[this] var outputNull = false - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 - } - } + /** Create a */ + private[this] val projection = newMutableProjection(expressions.toSeq, inputSchema)() + projection.target(target) - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 - } + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + input = rows + inputIndex = offset } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(): Unit = { + val size = input.size + if (inputIndex >= 0 && inputIndex < size) { + projection(input(inputIndex)) + outputNull = false } + else if (!outputNull) { + var i = 0 + while (i < expressions.length) { + target.setNullAt(i) + i += 1 + } + outputNull = true + } + inputIndex += 1 } } @@ -501,16 +470,16 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ private[this] var input: CompactBuffer[InternalRow] = null @@ -524,7 +493,7 @@ private[execution] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new util.ArrayDeque[Array[AnyRef]] + private[this] val buffer = new util.ArrayDeque[InternalRow] /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 @@ -539,14 +508,14 @@ private[execution] final class SlidingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. 
*/ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. while (inputHighIndex < input.size && ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + buffer.offer(input(inputHighIndex)) inputHighIndex += 1 bufferUpdated = true } @@ -562,17 +531,17 @@ private[execution] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + val iterator = buffer.iterator() + val status = processor.initialize + while (iterator.hasNext) { + processor.update(status, iterator.next()) + } + processor.evaluate(target, status) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -583,36 +552,30 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + /** The collected aggregate status of all rows in the input. */ + private[this] var status: MutableRow = _ /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 + status = processor.initialize val iterator = rows.iterator while (iterator.hasNext) { - update(iterator.next()) + processor.update(status, iterator.next()) } - evaluate() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. + processor.evaluate(target, status) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -625,14 +588,14 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. 
*/ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ private[this] var input: CompactBuffer[InternalRow] = null @@ -644,39 +607,37 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 + /** The collected aggregate status of all rows seen so far. */ + private[this] var status: MutableRow = _ + /** Prepare the frame for calculating a new partition. */ override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() input = rows inputIndex = 0 outputIndex = 0 + status = processor.initialize } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + processor.update(status, input(inputIndex)) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target, status) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -691,17 +652,14 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * buffer and must do full recalculation after each row. Reverse iteration would be possible, if * the communitativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ private[this] var input: CompactBuffer[InternalRow] = null @@ -718,18 +676,10 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( input = rows inputIndex = 0 outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. 
*/ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Drop all rows from the buffer for which the input row value is smaller than @@ -741,15 +691,205 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + var i = inputIndex + val size = input.size + val status = processor.initialize + while (i < size) { + processor.update(status, input(i)) + i += 1 + } + processor.evaluate(target, status) } // Move to the next row. outputIndex += 1 } +} + +/** + * This class prepares and manages the processing of a number of aggregate functions. + * + * The following aggregates are supported: + * [[AggregateExpression1]] + * [[AggregateExpression2]] + * [[AggregateFunction2]] + * [[AlgebraicAggregate]] + * + * Note that the [[AggregateExpression1]] code path will probably be removed in SPARK 1.6.0. + * + * The current implementation only supports evaluation in [[Complete]] mode. This is enough for + * Window processing. Adding other processing modes is dependent on the support of + * [[AggregateExpression1]]. + * + * Processing of any number of distinct aggregates is supported using Set operations. More + * advanced distinct operators (e.g. Sort Based Operators) should be added before the + * [[AggregateProcessor]] is created. + * + * The implementation is split into an object which takes care of construction, and a the actual + * processor class. Construction might be expensive and could be separated into a 'driver' and a + * 'executor' part. + */ +private[execution] object AggregateProcessor { + def apply(functions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val bufferSchema = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.empty[Expression] + val aggregates1 = mutable.Buffer.empty[AggregateExpression1] + val aggregates1BufferOffsets = mutable.Buffer.empty[Int] + val aggregates1OutputOffsets = mutable.Buffer.empty[Int] + val aggregates2 = mutable.Buffer.empty[AggregateFunction2] + val aggregates2OutputOffsets = mutable.Buffer.empty[Int] + + // Flatten AggregateExpression2's + val flattened = functions.zipWithIndex.map { + case (AggregateExpression2(af2, _, distinct), i) => (af2, distinct, i) + case (e, i) => (e, false, i) + } + + // Add distinct evaluation path. + val distinctExpressionSchemaMap = mutable.HashMap.empty[Seq[Expression], AttributeReference] + flattened.filter(_._2).foreach { + case (af2, _, _) => + // TODO cannocalize expressions? + val children = af2.children + if (!distinctExpressionSchemaMap.contains(af2.children)) { + // TODO Typing? + val ref = AttributeReference("de", new OpenHashSetUDT(NullType), nullable = false)() + distinctExpressionSchemaMap += children -> ref + bufferSchema += ref + initialValues += NewSet(NullType) + if (children.size > 1) { + updateExpressions += CreateStruct(children) + } else { + updateExpressions += children.head + } + } + } + + // Add functions. 
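// Distinct AlgebraicAggregates / AggregateFunction2s are reduced from their
// collected sets at evaluation time; non-distinct AlgebraicAggregates contribute
// their buffer, update and evaluate expressions to the shared projections;
// imperative AggregateFunction2s and legacy AggregateExpression1s reserve buffer
// slots and are driven directly by the AggregateProcessor.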
+ flattened.foreach { + case (agg: AlgebraicAggregate, true, _) => + val ref = distinctExpressionSchemaMap(agg.children) + evaluateExpressions += ReduceSetAlgebraic(ref, agg) + case (agg: AlgebraicAggregate, false, _) => + agg.bufferOffset = bufferSchema.size + bufferSchema ++= agg.bufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case (agg: AggregateFunction2, true, _) => + val ref = distinctExpressionSchemaMap(agg.children) + evaluateExpressions += ReduceSetAggregate(ref, agg) + case (agg: AggregateFunction2, false, i) => + aggregates2 += agg + aggregates2OutputOffsets += i + agg.bufferOffset = bufferSchema.size + bufferSchema ++= agg.bufferAttributes + val nops = Seq.fill(agg.bufferAttributes.size)(NoOp) + initialValues ++= nops + updateExpressions ++= nops + evaluateExpressions += NoOp + case (agg: AggregateExpression1, false, i) => + aggregates1 += agg + aggregates1BufferOffsets += bufferSchema.size + aggregates1OutputOffsets += i + // TODO typing + bufferSchema += AttributeReference("agg", NullType, nullable = false)() + initialValues += NoOp + updateExpressions += NoOp + evaluateExpressions += NoOp + } + + // Create the projections. + val initialProjection = newMutableProjection(initialValues, Nil)() + val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)() + val evaluateProjection = newMutableProjection(evaluateExpressions, bufferSchema)() + + // Create the processor + new AggregateProcessor(bufferSchema.toArray, initialProjection, updateProjection, + evaluateProjection, aggregates2.toArray, aggregates2OutputOffsets.toArray, + aggregates1.toArray, aggregates1BufferOffsets.toArray, aggregates1OutputOffsets.toArray) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val aggregates2: Array[AggregateFunction2], + private[this] val aggregates2OutputOffsets: Array[Int], + private[this] val aggregates1: Array[AggregateExpression1], + private[this] val aggregates1BufferOffsets: Array[Int], + private[this] val aggregates1OutputOffsets: Array[Int]) { + + private[this] val join = new JoinedRow + private[this] val bufferSchemaSize = bufferSchema.length + private[this] val aggregates2Size = aggregates2.length + private[this] val aggregates1Size = aggregates1.length + + // Create the initial state + def initialize: MutableRow = { + val buffer = new GenericMutableRow(bufferSchemaSize) + initialProjection.target(buffer)(EmptyRow) + var i = 0 + while (i < aggregates2Size) { + aggregates2(i).initialize(buffer) + i += 1 + } + i = 0 + while (i < aggregates1Size) { + buffer(aggregates1BufferOffsets(i)) = aggregates1(i).newInstance() + i += 1 + } + buffer + } + + // Update the buffer. + def update(buffer: MutableRow, input: InternalRow): Unit = { + updateProjection.target(buffer)(join(buffer, input)) + var i = 0 + while (i < aggregates2Size) { + aggregates2(i).update(buffer, input) + i += 1 + } + i = 0 + while (i < aggregates1Size) { + buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i)).update(input) + i += 1 + } + } + + // Evaluate buffer. 
+ def evaluate(target: MutableRow, buffer: MutableRow): Unit = { + evaluateProjection.target(target)(buffer) + var i = 0 + while (i < aggregates2Size) { + val value = aggregates2(i).eval(buffer) + target.update(aggregates2OutputOffsets(i), value) + i += 1 + } + i = 0 + while (i < aggregates1Size) { + val value = buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i)).eval(EmptyRow) + target.update(aggregates1OutputOffsets(i), value) + i += 1 + } + } +} - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) +private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRow) + extends MutableRow { + def setNullAt(i: Int): Unit = delegate.setNullAt(i + offset) + def update(i: Int, value: Any): Unit = delegate.update(i + offset, value) + def get(i: Int): Any = delegate.get(i + offset) + def numFields: Int = delegate.numFields - offset } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c3d224629702..53781f5c1266 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -139,36 +139,8 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") - } + val windowExpr = WindowExpression(aggregate.expr, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) new Column(windowExpr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 79c5f596661d..ddd0f9217983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -346,7 +346,7 @@ object functions { * @since 1.4.0 */ def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) + new CumeDist() } /** @@ -363,7 +363,7 @@ object 
functions {
   * @since 1.4.0
   */
  def denseRank(): Column = {
-    UnresolvedWindowFunction("dense_rank", Nil)
+    new DenseRank()
  }

  /**
@@ -419,7 +419,7 @@ object functions {
   * @since 1.4.0
   */
  def lag(e: Column, offset: Int, defaultValue: Any): Column = {
-    UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil)
+    Lag(e.expr, offset, Literal(defaultValue))
  }

  /**
@@ -475,7 +475,7 @@ object functions {
   * @since 1.4.0
   */
  def lead(e: Column, offset: Int, defaultValue: Any): Column = {
-    UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil)
+    Lead(e.expr, offset, Literal(defaultValue))
  }

  /**
@@ -489,7 +489,7 @@ object functions {
   * @since 1.4.0
   */
  def ntile(n: Int): Column = {
-    UnresolvedWindowFunction("ntile", lit(n).expr :: Nil)
+    new NTile(n)
  }

  /**
@@ -506,7 +506,7 @@ object functions {
   * @since 1.4.0
   */
  def percentRank(): Column = {
-    UnresolvedWindowFunction("percent_rank", Nil)
+    new PercentRank()
  }

  /**
@@ -523,7 +523,7 @@ object functions {
   * @since 1.4.0
   */
  def rank(): Column = {
-    UnresolvedWindowFunction("rank", Nil)
+    new Rank()
  }

  /**
@@ -535,7 +535,7 @@ object functions {
   * @since 1.4.0
   */
  def rowNumber(): Column = {
-    UnresolvedWindowFunction("row_number", Nil)
+    RowNumber()
  }

  //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index c3f29350101d..a0b7611ba478 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -331,6 +331,15 @@ private[hive] object HiveQl extends Logging {
  }

  /** Extractor for matching Hive's AST Tokens. */
+  /** A Token that can be used to construct Hive AST nodes for pattern matching. */
+  private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node {
+    def getName(): String = name
+    def getChildren(): java.util.List[Node] = {
+      val col = new java.util.ArrayList[Node](children.size)
+      children.foreach(col.add(_))
+      col
+    }
+  }
  object Token {
    /** @return matches of the form (tokenName, children). */
    def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
@@ -338,6 +347,7 @@ private[hive] object HiveQl extends Logging {
        CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine)
        Some((t.getText,
          Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
+      case t: Token => Some((t.name, t.children))
      case _ => None
    }
  }
@@ -1444,17 +1454,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
      UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))

    /* Window Functions */
-    case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) =>
-      val function = UnresolvedWindowFunction(name, args.map(nodeToExpr))
-      nodesToWindowSpecification(spec) match {
-        case reference: WindowSpecReference =>
-          UnresolvedWindowExpression(function, reference)
-        case definition: WindowSpecDefinition =>
-          WindowExpression(function, definition)
-      }
-    case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
-      // Safe to use Literal(1)? 
- val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) + case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(Token(name, args)) nodesToWindowSpecification(spec) match { case reference: WindowSpecReference => UnresolvedWindowExpression(function, reference) From d348482679f6a7cde3c1cb5575eb0313d708dfd4 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 26 Jul 2015 00:45:33 -0400 Subject: [PATCH 03/19] Rebase & make it compile again. --- .../apache/spark/sql/catalyst/expressions/aggregate/sets.scala | 2 +- .../spark/sql/catalyst/expressions/windowExpressions.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala index 5f8e4618f3b5..67d01d63a3a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala @@ -52,7 +52,7 @@ case class ReduceSetAlgebraic(left: Expression, right: AlgebraicAggregate) @transient private[this] lazy val evaluate = BindReferences.bindReference(right.evaluateExpression, right.bufferSchema.toAttributes) - @transient private[this] lazy val joinRow = new JoinedRow4 + @transient private[this] lazy val joinRow = new JoinedRow override def eval(input: InternalRow): Any = { val result = left.eval(input).asInstanceOf[OpenHashSet[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 4a80100b9cfe..16b2eb341401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -400,7 +400,6 @@ abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFun self: Product => override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType - override def foldable: Boolean = false override def nullable: Boolean = false override val mergeExpressions = Nil // TODO how to deal with this? } From 62c456f66e2d157012b51c6431a1e4b0b395c933 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 26 Jul 2015 12:56:36 -0400 Subject: [PATCH 04/19] Bug fixes.... 
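
The biggest change here: the offset of Lead/Lag becomes an Expression instead of an Int. Since the frame
bounds of an offset function are fixed per window, the offset has to be foldable so it can be evaluated
once, at construction time, without an input row. A minimal sketch of that contract (the helper name is
illustrative, not code from this patch; require and eval() are used exactly as the patch does):

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

    // Evaluate a foldable offset once, without an input row; reject anything row-dependent.
    def extractOffset(offset: Expression): Int = {
      require(offset.foldable, "Offset must be foldable")
      offset.eval().asInstanceOf[Int]
    }

    // extractOffset(Literal(2)) == 2; a column reference would fail the require.
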
--- .../expressions/windowExpressions.scala | 32 +++++++++++-------- .../apache/spark/sql/execution/Window.scala | 25 ++++++++------- .../org/apache/spark/sql/functions.scala | 4 +-- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 16b2eb341401..592b334bb941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -346,10 +346,12 @@ trait SizeBasedWindowFunction extends WindowFunction2 { def withSize(n: MutableLiteral): SizeBasedWindowFunction } -abstract class OffsetWindowFunction(child: Expression, offset: Int, default: Expression) +abstract class OffsetWindowFunction(child: Expression, offset: Expression, default: Expression) extends Expression with WindowFunction2 with CodegenFallback { self: Product => + require(offset.foldable, "Offset must be foldable") + override lazy val resolved = child.resolved && default.resolved && child.dataType == default.dataType @@ -368,32 +370,36 @@ abstract class OffsetWindowFunction(child: Expression, offset: Int, default: Exp } override def toString: String = s"$simpleString($child, $offset, $default)" + + protected val offsetVal = offset.eval().asInstanceOf[Int] } -case class Lead(child: Expression, offset: Int, default: Expression) +case class Lead(child: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction(child, offset, default) { - def this(child: Expression, offset: Int) = + def this() = this(null, null, null) + def this(child: Expression, offset: Expression) = this(child, offset, Literal.create(null, child.dataType)) def this(child: Expression) = - this(child, 1, Literal.create(null, child.dataType)) + this(child, Literal(1), Literal.create(null, child.dataType)) override val frame = SpecifiedWindowFrame(RowFrame, - ValueFollowing(offset), - ValueFollowing(offset)) + ValueFollowing(offsetVal), + ValueFollowing(offsetVal)) } -case class Lag(child: Expression, offset: Int, default: Expression) +case class Lag(child: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction(child, offset, default) { - def this(child: Expression, offset: Int) = + def this() = this(null, null, null) + def this(child: Expression, offset: Expression) = this(child, offset, Literal.create(null, child.dataType)) def this(child: Expression) = - this(child, 1, Literal.create(null, child.dataType)) + this(child, Literal(1), Literal.create(null, child.dataType)) override val frame = SpecifiedWindowFrame(RowFrame, - ValuePreceding(offset), - ValuePreceding(offset)) + ValuePreceding(offsetVal), + ValuePreceding(offsetVal)) } abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFunction2 { @@ -486,11 +492,11 @@ abstract class RankLike(order: Seq[SortOrder]) extends AggregateWindowFunction { protected val updateRank = If(And(orderEquals, rank !== 0), rank, doUpdateRank) // Implementation for RANK() - protected val doUpdateRank: Expression = rowNumber + 1L + protected val doUpdateRank: Expression = rowNumber + 1 override val bufferAttributes = rank +: rowNumber +: orderAttrs override val initialValues = Literal(0) +: Literal(0) +: orderInit override val updateExpressions = doUpdateRank +: (rowNumber + 1) +: orderExprs - override val 
evaluateExpression: Expression = Cast(rank, LongType)
+  override val evaluateExpression: Expression = Cast(rank, IntegerType)

  def withOrder(order: Seq[SortOrder]): RankLike
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 834cc8810004..b04b7fe0f127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -264,15 +264,13 @@ case class Window(
      // Track the unbound expressions
      unboundExpressions ++= unboundFrameExpressions

-      // Bind the expressions.
+      // Add ordering clause to ranking functions... Move code below to analyser? The dependency
+      // used in the pattern match might be too narrow (only RankLike and its subclasses).
      val functions = unboundFrameExpressions.map { e =>
-        // Perhaps move code below to analyser. The dependency used in the pattern match might
-        // be to narrow (only RankLike and its subclasses).
-        val function = e.windowFunction match {
+        e.windowFunction match {
          case r: RankLike => r.withOrder(windowSpec.orderSpec)
          case f => f
        }
-        BindReferences.bindReference(function, child.output)
      }.toArray

      // Create the frame processor factory.
@@ -763,9 +761,9 @@ private[execution] object AggregateProcessor {
          bufferSchema += ref
          initialValues += NewSet(NullType)
          if (children.size > 1) {
-            updateExpressions += CreateStruct(children)
+            updateExpressions += AddItemToSet(CreateStruct(children), ref)
          } else {
-            updateExpressions += children.head
+            updateExpressions += AddItemToSet(children.head, ref)
          }
        }
    }
@@ -785,23 +783,26 @@ private[execution] object AggregateProcessor {
        val ref = distinctExpressionSchemaMap(agg.children)
        evaluateExpressions += ReduceSetAggregate(ref, agg)
      case (agg: AggregateFunction2, false, i) =>
-        aggregates2 += agg
+        val boundAgg = BindReferences.bindReference(agg, inputSchema)
+        aggregates2 += boundAgg
        aggregates2OutputOffsets += i
        agg.bufferOffset = bufferSchema.size
-        bufferSchema ++= agg.bufferAttributes
-        val nops = Seq.fill(agg.bufferAttributes.size)(NoOp)
+        bufferSchema ++= boundAgg.bufferAttributes
+        val nops = Seq.fill(boundAgg.bufferAttributes.size)(NoOp)
        initialValues ++= nops
        updateExpressions ++= nops
        evaluateExpressions += NoOp
      case (agg: AggregateExpression1, false, i) =>
-        aggregates1 += agg
+        aggregates1 += BindReferences.bindReference(agg, inputSchema)
        aggregates1BufferOffsets += bufferSchema.size
        aggregates1OutputOffsets += i
-        // TODO typing
+        // TODO typing - we would need to create UDT for this.
        bufferSchema += AttributeReference("agg", NullType, nullable = false)()
        initialValues += NoOp
        updateExpressions += NoOp
        evaluateExpressions += NoOp
+      case (agg, distinct, i) =>
+        sys.error(s"Unsupported Aggregate $agg[distinct=$distinct, index=$i]")
    }

    // Create the projections. 
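
For reference, the distinct path wired up above never feeds rows into the aggregate directly: the update
expressions only add the (possibly struct-wrapped) input to a set via AddItemToSet, and the set is folded
into the aggregate at evaluation time by ReduceSetAlgebraic/ReduceSetAggregate. A self-contained sketch of
that evaluation order with plain Scala collections (illustrative only, not the catalyst code):

    object DistinctSketch {
      // DISTINCT SUM over a frame: updates add to a set, evaluate folds the set once.
      def distinctSum(frame: Seq[Int]): Int = {
        val seen = scala.collection.mutable.LinkedHashSet.empty[Int]
        frame.foreach(seen += _) // update path, cf. AddItemToSet
        seen.iterator.sum        // evaluate path, cf. ReduceSet*
      }
    }

    // distinctSum(Seq(1, 1, 2, 3, 3)) == 6
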
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ddd0f9217983..fba121699f52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -419,7 +419,7 @@ object functions { * @since 1.4.0 */ def lag(e: Column, offset: Int, defaultValue: Any): Column = { - Lag(e.expr, offset, Literal(defaultValue)) + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -475,7 +475,7 @@ object functions { * @since 1.4.0 */ def lead(e: Column, offset: Int, defaultValue: Any): Column = { - Lead(e.expr, offset, Literal(defaultValue)) + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** From bcff01b7d6ea978e1542e9c6a3761865ae1d4560 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 27 Jul 2015 15:30:47 -0400 Subject: [PATCH 05/19] Add NoOp to interpreted projections. --- .../spark/sql/catalyst/expressions/Projection.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index afe52e6a667e..540c1aec6fd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -43,7 +44,10 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { val outputArray = new Array[Any](exprArray.length) var i = 0 while (i < exprArray.length) { - outputArray(i) = exprArray(i).eval(input) + exprArray(i) match { + case NoOp => + case e => outputArray(i) = e.eval(input) + } i += 1 } new GenericInternalRow(outputArray) @@ -79,7 +83,10 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { - mutableRow(i) = exprArray(i).eval(input) + exprArray(i) match { + case NoOp => + case e => mutableRow(i) = e.eval(input) + } i += 1 } mutableRow From 98c11b688bc8e3b1550f7faf2ffa873ba17a2e0d Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 27 Jul 2015 23:52:33 -0400 Subject: [PATCH 06/19] Bugfixes... All tests work now. 
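
The core of the rank rework: RankLike now derives its value from a running row number plus a check whether
the ORDER BY expressions changed between consecutive rows. Roughly, as a plain Scala model of the update
rule (a sketch, not the catalyst expressions below):

    def ranks[T](ordered: Seq[T]): Seq[Int] = {
      var rank = 0
      var prev: Option[T] = None
      ordered.zipWithIndex.map { case (value, i) =>
        val rowNumber = i + 1
        if (prev != Some(value)) rank = rowNumber // order changed: RANK jumps to rowNumber
        prev = Some(value)
        rank
      }
    }

    // ranks(Seq("a", "a", "b")) == Seq(1, 1, 3); DENSE_RANK uses rank + 1 on change instead.
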
--- .../sql/catalyst/analysis/Analyzer.scala | 25 ++- .../expressions/windowExpressions.scala | 143 +++++++++--------- .../apache/spark/sql/execution/Window.scala | 102 ++++++++----- .../org/apache/spark/sql/functions.scala | 2 +- 4 files changed, 161 insertions(+), 111 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 98ab61c57648..bcb0e3a24f60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -76,6 +76,7 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowOrder :: ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: @@ -526,6 +527,8 @@ class Analyzer( case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { + // DISTINCT is not meaningful in case of WindowFunctions. + case wf: WindowFunction2 => wf // We get an aggregate function built based on AggregateFunction2 interface. // So, we wrap it in AggregateExpression2. case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) @@ -1040,22 +1043,36 @@ class Analyzer( */ object ResolveWindowFrame extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical.transformExpressionsDown { + case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction2, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != UnspecifiedFrame && wf.frame != f => - failAnalysis(s"The frame of the window '$f' does not match the required frame " + - s"'${wf.frame}'") + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") case WindowExpression(wf: WindowFunction2, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => - val frame = SpecifiedWindowFrame.defaultWindowFrame(!o.isEmpty, false) + val frame = SpecifiedWindowFrame.defaultWindowFrame(!o.isEmpty, true) we.copy(windowSpec = s.copy(frameSpecification = frame)) } } } + + /** + * Check and add order to [[AggregateWindowFunction]]s. 
+ */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(agg: AggregateWindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"AggregateWindowFunction $agg window specification must be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 592b334bb941..815d0403d70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.aggregate.AlgebraicAggregate -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -343,63 +342,63 @@ trait WindowFunction2 extends Expression { } trait SizeBasedWindowFunction extends WindowFunction2 { - def withSize(n: MutableLiteral): SizeBasedWindowFunction + def withSize(n: Expression): SizeBasedWindowFunction } -abstract class OffsetWindowFunction(child: Expression, offset: Expression, default: Expression) - extends Expression with WindowFunction2 with CodegenFallback { +abstract class OffsetWindowFunction + extends Expression with WindowFunction2 with Unevaluable with ImplicitCastInputTypes { self: Product => + val input: Expression + val default: Expression + val offset: Expression + val offsetSign: Int + def offsetValue: Int = offset.eval().asInstanceOf[Int] - require(offset.foldable, "Offset must be foldable") + override def children: Seq[Expression] = Seq(input, offset, default) - override lazy val resolved = - child.resolved && default.resolved && child.dataType == default.dataType + override def foldable: Boolean = input.foldable && (default == null || default.foldable) - override def children: Seq[Expression] = child :: default :: Nil + override def nullable: Boolean = input.nullable && (default == null || default.nullable) - override def dataType: DataType = child.dataType + override lazy val frame = { + val boundary = ValueFollowing(offsetSign * offsetValue) + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } - override def foldable: Boolean = child.foldable && default.foldable + override def dataType: DataType = input.dataType - override def nullable: Boolean = child.nullable && default.nullable + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) - override def eval(input: InternalRow): Any = { - val result = child.eval(input) - if (result != null) result - else default.eval(input) - } + override def toString: String = s"$prettyName($input, $offset, $default)" +} - override def toString: String = s"$simpleString($child, $offset, $default)" +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - protected val offsetVal = 
offset.eval().asInstanceOf[Int] -} + def this(input: Expression, offset: Expression) = + this(input, offset, Literal(null)) -case class Lead(child: Expression, offset: Expression, default: Expression) - extends OffsetWindowFunction(child, offset, default) { - def this() = this(null, null, null) - def this(child: Expression, offset: Expression) = - this(child, offset, Literal.create(null, child.dataType)) + def this(input: Expression) = + this(input, Literal(1), Literal(null)) - def this(child: Expression) = - this(child, Literal(1), Literal.create(null, child.dataType)) + def this() = this(Literal(null), Literal(1), Literal(null)) - override val frame = SpecifiedWindowFrame(RowFrame, - ValueFollowing(offsetVal), - ValueFollowing(offsetVal)) + val offsetSign = 1 } -case class Lag(child: Expression, offset: Expression, default: Expression) - extends OffsetWindowFunction(child, offset, default) { - def this() = this(null, null, null) - def this(child: Expression, offset: Expression) = - this(child, offset, Literal.create(null, child.dataType)) +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - def this(child: Expression) = - this(child, Literal(1), Literal.create(null, child.dataType)) + def this(input: Expression, offset: Expression) = + this(input, offset, Literal(null)) - override val frame = SpecifiedWindowFrame(RowFrame, - ValuePreceding(offsetVal), - ValuePreceding(offsetVal)) + def this(input: Expression) = + this(input, Literal(1), Literal(null)) + + def this() = this(Literal(null), Literal(1), Literal(null)) + + val offsetSign = -1 } abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFunction2 { @@ -412,7 +411,6 @@ abstract class AggregateWindowFunction extends AlgebraicAggregate with WindowFun abstract class RowNumberLike extends AggregateWindowFunction { override def children: Seq[Expression] = Nil - override def deterministic: Boolean = false override def inputTypes: Seq[AbstractDataType] = Nil protected val rowNumber = AttributeReference("rowNumber", IntegerType)() override val bufferAttributes: Seq[AttributeReference] = rowNumber :: Nil @@ -425,22 +423,23 @@ case class RowNumber() extends RowNumberLike { } // TODO check if this works in combination with CodeGeneration? -case class CumeDist(n: MutableLiteral) extends RowNumberLike with SizeBasedWindowFunction { - def this() = this(MutableLiteral(0, IntegerType)) +case class CumeDist(n: Expression) extends RowNumberLike with SizeBasedWindowFunction { + def this() = this(Literal(0)) override def dataType: DataType = DoubleType override def deterministic: Boolean = true - override def withSize(n: MutableLiteral): CumeDist = CumeDist(n) + override def withSize(n: Expression): CumeDist = CumeDist(n) override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - override val evaluateExpression = Cast(rowNumber / n, DoubleType) + override val evaluateExpression = Cast(rowNumber, DoubleType) / Cast(n, DoubleType) } // TODO check if this works in combination with CodeGeneration? 
// TODO check logic -case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike +// Check serialization +case class NTile(buckets: Expression, n: Expression) extends RowNumberLike with SizeBasedWindowFunction { - require(buckets > 0, "Number of buckets must be > 0") - def this(buckets: Int) = this(MutableLiteral(0, IntegerType), buckets) - override def withSize(n: MutableLiteral): NTile = NTile(n, buckets) + def this() = this(Literal(1), Literal(0)) + def this(buckets: Expression) = this(buckets, Literal(0)) + override def withSize(n: Expression): NTile = NTile(buckets, n) private val bucket = AttributeReference("bucket", IntegerType)() private val bucketThreshold = AttributeReference("bucketThreshold", IntegerType)() private val bucketSize = AttributeReference("bucketSize", IntegerType)() @@ -464,9 +463,9 @@ case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike override val updateExpressions = Seq( rowNumber + 1, - bucket + If(rowNumber >= bucketThreshold, 1, 0), + bucket + If(rowNumber > bucketThreshold, 1, 0), bucketThreshold + - If(rowNumber >= bucketThreshold, bucketSize + If(bucket <= bucketsWithPadding, 1, 0), 0), + If(rowNumber > bucketThreshold, bucketSize + If(bucket <= bucketsWithPadding, 1, 0), 0), bucketSize, bucketsWithPadding ) @@ -474,54 +473,54 @@ case class NTile(n: MutableLiteral, buckets: Int) extends RowNumberLike override val evaluateExpression = bucket } -abstract class RankLike(order: Seq[SortOrder]) extends AggregateWindowFunction { +abstract class RankLike extends AggregateWindowFunction { override def children: Seq[Expression] = order - override def deterministic: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - protected val orderExprs = order.map(_.expr) + override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) - protected val orderAttrs = orderExprs.zipWithIndex.map{ case (expr, i) => + val order: Seq[Expression] + protected val orderAttrs = order.zipWithIndex.map{ case (expr, i) => AttributeReference(i.toString, expr.dataType)() } - protected val orderEquals = orderExprs.zip(orderAttrs).map(EqualNullSafe.tupled).reduce(And) - protected val orderInit = orderExprs.map(e => Literal.create(null, e.dataType)) + protected val orderEquals = + order.zip(orderAttrs).map(EqualNullSafe.tupled).reduceOption(And).getOrElse(Literal(true)) + protected val orderInit = order.map(e => Literal.create(null, e.dataType)) protected val rank = AttributeReference("rank", IntegerType)() protected val rowNumber = AttributeReference("rowNumber", IntegerType)() - protected val updateRank = If(And(orderEquals, rank !== 0), rank, doUpdateRank) // Implementation for RANK() - protected val doUpdateRank: Expression = rowNumber + 1 + protected def doUpdateRank: Expression = rowNumber + 1 + protected def updateRank = If(And(orderEquals, rank !== 0), rank, doUpdateRank) override val bufferAttributes = rank +: rowNumber +: orderAttrs override val initialValues = Literal(0) +: Literal(0) +: orderInit - override val updateExpressions = doUpdateRank +: (rowNumber + 1) +: orderExprs + override val updateExpressions = updateRank +: (rowNumber + 1) +: order override val evaluateExpression: Expression = Cast(rank, IntegerType) - def withOrder(order: Seq[SortOrder]): RankLike + def withOrder(order: Seq[Expression]): RankLike } -case class Rank(order: Seq[SortOrder]) extends RankLike(order) { +case class Rank(order: Seq[Expression]) extends RankLike { def this() = this(Nil) - override def withOrder(order: 
Seq[SortOrder]): Rank = Rank(order) + override def withOrder(order: Seq[Expression]): Rank = Rank(order) } -case class DenseRank(order: Seq[SortOrder]) extends RankLike(order) { +case class DenseRank(order: Seq[Expression]) extends RankLike { def this() = this(Nil) - override def withOrder(order: Seq[SortOrder]): DenseRank = DenseRank(order) + override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) + override protected def doUpdateRank = rank + 1 + override val updateExpressions = updateRank +: order override val bufferAttributes = rank +: orderAttrs override val initialValues = Literal(0) +: orderInit - override val updateExpressions = doUpdateRank +: orderExprs - override val doUpdateRank = rank + 1 } // TODO check if this works in combination with CodeGeneration? -case class PercentRank(order: Seq[SortOrder], n: MutableLiteral) extends RankLike(order) +case class PercentRank(order: Seq[Expression], n: Expression) extends RankLike with SizeBasedWindowFunction { def this() = this(Nil, MutableLiteral(0, IntegerType)) - override def withOrder(order: Seq[SortOrder]): PercentRank = PercentRank(order, n) - override def withSize(n: MutableLiteral): PercentRank = PercentRank(order, n) + override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order, n) + override def withSize(n: Expression): PercentRank = PercentRank(order, n) override def dataType: DataType = DoubleType override val evaluateExpression = - If(n > 1, Cast((rank - 1) / (n - 1), DoubleType), Literal.create(0.0d, DoubleType)) + If(n > 1, Cast(rank - 1, DoubleType) / Cast(n - 1, DoubleType), + Literal.create(0.0d, DoubleType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b04b7fe0f127..a0a5ea824c8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.util +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -169,7 +170,7 @@ case class Window( * @return a frame processor. */ private[this] def createFrameProcessor( - frame: WindowFrame, + frame: (Char, WindowFrame), functions: Array[Expression], ordinal: Int, result: MutableRow, @@ -190,33 +191,43 @@ case class Window( // Create the frame processor. frame match { // Offset Frame - case SpecifiedWindowFrame(RowFrame, FrameBoundaryExtractor(l), FrameBoundaryExtractor(h)) + case ('O', SpecifiedWindowFrame(RowFrame, + FrameBoundaryExtractor(l), + FrameBoundaryExtractor(h))) if l == h => new OffsetWindowFunctionFrame(target, functions, child.output, newMutableProjection, l) // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => + case ('A', SpecifiedWindowFrame(frameType, + UnboundedPreceding, + FrameBoundaryExtractor(high))) => val uBoundOrdering = createBoundOrdering(frameType, high) new UnboundedPrecedingWindowFunctionFrame(target, processor, uBoundOrdering) // Shrinking Frame. 
- case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => + case ('A', SpecifiedWindowFrame(frameType, + FrameBoundaryExtractor(low), + UnboundedFollowing)) => val lBoundOrdering = createBoundOrdering(frameType, low) new UnboundedFollowingWindowFunctionFrame(target, processor, lBoundOrdering) // Moving Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(l), FrameBoundaryExtractor(h)) => + case ('A', SpecifiedWindowFrame(frameType, + FrameBoundaryExtractor(l), + FrameBoundaryExtractor(h))) => val lBoundOrdering = createBoundOrdering(frameType, l) val uBoundOrdering = createBoundOrdering(frameType, h) new SlidingWindowFunctionFrame(target, processor, lBoundOrdering, uBoundOrdering) // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => + case ('A', SpecifiedWindowFrame(_, + UnboundedPreceding, + UnboundedFollowing)) => new UnboundedWindowFunctionFrame(target, processor) // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + case (mode, fr) => + sys.error(s"Unsupported Mode $mode Frame $fr for functions: $functions") } } @@ -242,17 +253,23 @@ case class Window( protected override def doExecute(): RDD[InternalRow] = { // Prepare processing. - // Group the window expression by their processing frame. + // Collect window expressions. val windowExprs = windowExpression.flatMap { _.collect { case e: WindowExpression => e } } + // Group the window expression by their processing frame and mode. + val framedWindowExprs = windowExprs.groupBy { + case e @ WindowExpression(_: AggregateExpression, spec) => ('A', spec.frameSpecification) + case e @ WindowExpression(_: AggregateFunction2, spec) => ('A', spec.frameSpecification) + case e @ WindowExpression(_, spec) => ('O', spec.frameSpecification) + } + // Create Frame processor factories and order the unbound window expressions by the frame they // are processed in; this is the order in which their results will be written to window // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) val factories = Array.ofDim[(MutableRow, MutableLiteral) => WindowFunctionFrame](framedWindowExprs.size) val unboundExpressions = mutable.Buffer.empty[Expression] @@ -264,14 +281,8 @@ case class Window( // Track the unbound expressions unboundExpressions ++= unboundFrameExpressions - // Add ordering clause to ranking functions... Move code below to analyser? The dependency - // used in the pattern match might be to narrow (only RankLike and its subclasses). - val functions = unboundFrameExpressions.map { e => - e.windowFunction match { - case r: RankLike => r.withOrder(windowSpec.orderSpec) - case f => f - } - }.toArray + // Extract functions from frame. + val functions = unboundFrameExpressions.map(_.windowFunction).toArray // Create the frame processor factory. factories(index) = (result: MutableRow, size: MutableLiteral) => @@ -431,36 +442,59 @@ private[execution] final class OffsetWindowFunctionFrame( /** Rows of the partition currently being processed. */ private[this] var input: CompactBuffer[InternalRow] = null - /** Index of the row we are currently using for output. */ + /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 - /** Check if the output has been explicitly cleared. */ - private[this] var outputNull = false + /** Index of the current output row. 
*/ + private[this] var outputIndex = 0 - /** Create a */ - private[this] val projection = newMutableProjection(expressions.toSeq, inputSchema)() - projection.target(target) + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Create an input schema to bind the default expressions to. + val defaultInputSchema = inputSchema.map(_.newInstance()) ++ inputSchema + + // Collect the expressions and bind them. + val boundExpressions = expressions.toSeq.map { + case e: OffsetWindowFunction => + val boundLeft = BindReferences.bindReference(e.input, inputSchema) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + boundLeft + } else { + // With default value. + val boundRight = BindReferences.bindReference(e.default, defaultInputSchema) + Coalesce(boundLeft :: boundRight :: Nil) + } + case e => + BindReferences.bindReference(e, inputSchema) + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) + } override def prepare(rows: CompactBuffer[InternalRow]): Unit = { input = rows inputIndex = offset + outputIndex = 0 } override def write(): Unit = { val size = input.size if (inputIndex >= 0 && inputIndex < size) { - projection(input(inputIndex)) - outputNull = false - } - else if (!outputNull) { - var i = 0 - while (i < expressions.length) { - target.setNullAt(i) - i += 1 - } - outputNull = true + join(input(inputIndex), input(outputIndex)) + } else { + join(emptyRow, input(outputIndex)) } + projection(join) inputIndex += 1 + outputIndex += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fba121699f52..99c7aec89572 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -489,7 +489,7 @@ object functions { * @since 1.4.0 */ def ntile(n: Int): Column = { - new NTile(n) + new NTile(Literal(n)) } /** From 07d565565324c8df833a060158b70d85b8073326 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 29 Jul 2015 10:27:32 -0400 Subject: [PATCH 07/19] Update after rebase. Removal of redundant code. Tweaks... 
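
Main tweak: the sliding frame no longer materializes its rows in a java.util.ArrayDeque; it just moves two
indices over the partition buffer and recomputes the aggregate over [inputLowIndex, inputHighIndex) with
the new bulk update. A rough model of the per-row work (plain Scala, names illustrative):

    def slidingSums(rows: IndexedSeq[Int], preceding: Int, following: Int): IndexedSeq[Int] =
      rows.indices.map { out =>
        val lo = math.max(0, out - preceding)             // inputLowIndex
        val hi = math.min(rows.size, out + following + 1) // inputHighIndex
        (lo until hi).map(rows).sum // processor.update over [lo, hi), then evaluate
      }

    // slidingSums(Vector(1, 2, 3, 4), 1, 1) == Vector(3, 6, 9, 7)
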
--- .../apache/spark/sql/execution/Window.scala | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index a0a5ea824c8e..edf564bb4a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution -import java.util - -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.{StructType, NullType, IntegerType} +import org.apache.spark.sql.types.{DataType, NullType, IntegerType} import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.CompactBuffer import scala.collection.mutable @@ -462,21 +459,21 @@ private[execution] final class OffsetWindowFunctionFrame( // Collect the expressions and bind them. val boundExpressions = expressions.toSeq.map { case e: OffsetWindowFunction => - val boundLeft = BindReferences.bindReference(e.input, inputSchema) + val input = BindReferences.bindReference(e.input, inputSchema) if (e.default == null || e.default.foldable && e.default.eval() == null) { // Without default value. - boundLeft + input } else { // With default value. - val boundRight = BindReferences.bindReference(e.default, defaultInputSchema) - Coalesce(boundLeft :: boundRight :: Nil) + val default = BindReferences.bindReference(e.default, defaultInputSchema) + Coalesce(input :: default :: Nil) } case e => BindReferences.bindReference(e, inputSchema) - } + } - // Create the projection. - newMutableProjection(boundExpressions, Nil)().target(target) + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } override def prepare(rows: CompactBuffer[InternalRow]): Unit = { @@ -524,9 +521,6 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new util.ArrayDeque[InternalRow] - /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 @@ -536,7 +530,6 @@ private[execution] final class SlidingWindowFunctionFrame( inputHighIndex = 0 inputLowIndex = 0 outputIndex = 0 - buffer.clear() } /** Write the frame columns for the current row to the given target row. */ @@ -547,7 +540,6 @@ private[execution] final class SlidingWindowFunctionFrame( // the output row upper bound. while (inputHighIndex < input.size && ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(input(inputHighIndex)) inputHighIndex += 1 bufferUpdated = true } @@ -556,18 +548,14 @@ private[execution] final class SlidingWindowFunctionFrame( // the output row lower bound. while (inputLowIndex < inputHighIndex && lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. 
if (bufferUpdated) { - val iterator = buffer.iterator() val status = processor.initialize - while (iterator.hasNext) { - processor.update(status, iterator.next()) - } + processor.update(status, input, inputLowIndex, inputHighIndex) processor.evaluate(target, status) } @@ -597,10 +585,7 @@ private[execution] final class UnboundedWindowFunctionFrame( /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ override def prepare(rows: CompactBuffer[InternalRow]): Unit = { status = processor.initialize - val iterator = rows.iterator - while (iterator.hasNext) { - processor.update(status, iterator.next()) - } + processor.update(status, rows, 0, rows.size) } /** Write the frame columns for the current row to the given target row. */ @@ -723,13 +708,8 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - var i = inputIndex - val size = input.size val status = processor.initialize - while (i < size) { - processor.update(status, input(i)) - i += 1 - } + processor.update(status, input, inputIndex, input.size) processor.evaluate(target, status) } @@ -898,7 +878,28 @@ private[execution] final class AggregateProcessor( } i = 0 while (i < aggregates1Size) { - buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i)).update(input) + buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i), null).update(input) + i += 1 + } + } + + /** Bulk update the given buffer. */ + def update(buffer: MutableRow, input: CompactBuffer[InternalRow], begin: Int, end: Int): Unit = { + updateProjection.target(buffer) + var i = begin + while (i < end) { + val row = input(i) + updateProjection(join(buffer, row)) + var j = 0 + while (j < aggregates2Size) { + aggregates2(j).update(buffer, row) + j += 1 + } + j = 0 + while (j < aggregates1Size) { + buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(j), null).update(row) + j += 1 + } i += 1 } } @@ -914,7 +915,8 @@ private[execution] final class AggregateProcessor( } i = 0 while (i < aggregates1Size) { - val value = buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i)).eval(EmptyRow) + val function = buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i), null) + val value = function.eval(EmptyRow) target.update(aggregates1OutputOffsets(i), value) i += 1 } @@ -925,6 +927,6 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo extends MutableRow { def setNullAt(i: Int): Unit = delegate.setNullAt(i + offset) def update(i: Int, value: Any): Unit = delegate.update(i + offset, value) - def get(i: Int): Any = delegate.get(i + offset) + def get(i: Int, dataType: DataType): Any = delegate.get(i + offset, dataType) def numFields: Int = delegate.numFields - offset } From be5199cd76adfd86269259201a587c9e69048b1a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 29 Jul 2015 21:01:24 -0400 Subject: [PATCH 08/19] Tiny touchup. 
Make the Aggregate Processor a bit more DRY.
---
 .../org/apache/spark/sql/execution/Window.scala    | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index edf564bb4a7e..20a0e168f8be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -885,21 +885,9 @@ private[execution] final class AggregateProcessor(

  /** Bulk update the given buffer. */
  def update(buffer: MutableRow, input: CompactBuffer[InternalRow], begin: Int, end: Int): Unit = {
-    updateProjection.target(buffer)
    var i = begin
    while (i < end) {
-      val row = input(i)
-      updateProjection(join(buffer, row))
-      var j = 0
-      while (j < aggregates2Size) {
-        aggregates2(j).update(buffer, row)
-        j += 1
-      }
-      j = 0
-      while (j < aggregates1Size) {
-        buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(j), null).update(row)
-        j += 1
-      }
+      update(buffer, input(i))
      i += 1
    }
  }

From a21d67d3ee4df2bf225adfb015cded8bb727c193 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Thu, 30 Jul 2015 23:05:25 -0400
Subject: [PATCH 09/19] Rebase.

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
 .../apache/spark/sql/catalyst/expressions/aggregate/sets.scala  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bcb0e3a24f60..46abf8d0e1b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1038,7 +1038,7 @@ class Analyzer(
    }
  }

-  /*
+  /**
   * Check and add proper window frames for all window functions.
   */
  object ResolveWindowFrame extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala
index 67d01d63a3a5..d70a75b0b5f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala
@@ -79,7 +79,7 @@ case class ReduceSetAlgebraic(left: Expression, right: AlgebraicAggregate)
case class ReduceSetAggregate(left: Expression, right: AggregateFunction2)
  extends BinaryExpression with CodegenFallback {

-  right.bufferOffset = 0
+  right.mutableBufferOffset = 0

  override def dataType: DataType = right.dataType

From d27257cbc0b59e2f7b07996ce21c966ed7c0ebed Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Thu, 30 Jul 2015 23:08:28 -0400
Subject: [PATCH 10/19] Rebase. Performance improvements... 
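
The main gain comes from the Left/RightBoundReference pair: an update projection can read straight from
one half of the JoinedRow (buffer ++ input) because the side is resolved once at bind time, instead of
paying JoinedRow's per-access index test. A simplified model with plain arrays (illustrative only):

    // JoinedRow-style access: every get pays a bounds test plus offset arithmetic.
    def joinedGet(left: Array[Any], right: Array[Any], i: Int): Any =
      if (i < left.length) left(i) else right(i - left.length)

    // LeftBoundReference-style access: the side was chosen when the reference was bound.
    def leftGet(left: Array[Any], i: Int): Any = left(i)
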
--- .../catalyst/expressions/BoundAttribute.scala | 66 +++++++++++++++++-- .../apache/spark/sql/execution/Window.scala | 14 ++-- 2 files changed, 69 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 473b9b787058..e8fe09fe3963 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,13 +28,20 @@ import org.apache.spark.sql.types._ * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ -case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends LeafExpression with NamedExpression { +abstract class AbstractBoundReference extends LeafExpression with NamedExpression { + val ordinal: Int + + protected[this] def prefix: String = "" + + protected[this] def genCodeInput = "i" - override def toString: String = s"input[$ordinal, $dataType]" + protected[this] def unwrap(input: InternalRow): InternalRow = input + + override def toString: String = s"${prefix}input[$ordinal, $dataType]" // Use special getter for primitive types (for UnsafeRow) - override def eval(input: InternalRow): Any = { + override def eval(i: InternalRow): Any = { + val input = unwrap(i) if (input.isNullAt(ordinal)) { null } else { @@ -68,14 +75,35 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue("i", dataType, ordinal.toString) + val value = ctx.getValue(genCodeInput, dataType, ordinal.toString) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); + boolean ${ev.isNull} = $genCodeInput.isNullAt($ordinal); $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); """ } } +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends AbstractBoundReference + +case class LeftBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends AbstractBoundReference { + override protected def prefix = "left" + override protected def genCodeInput = + "((org.apache.spark.sql.catalyst.expressions.JoinedRow)i).left()" + override protected def unwrap(input: InternalRow): InternalRow = + input.asInstanceOf[JoinedRow].left +} + +case class RightBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends AbstractBoundReference { + override protected def prefix = "right" + override protected def genCodeInput = + "((org.apache.spark.sql.catalyst.expressions.JoinedRow)i).right()" + override protected def unwrap(input: InternalRow): InternalRow = + input.asInstanceOf[JoinedRow].right +} + object BindReferences extends Logging { def bindReference[A <: Expression]( @@ -97,4 +125,30 @@ object BindReferences extends Logging { } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. 
} + + def createJoinReferenceMap(left: Seq[Attribute], right: Seq[Attribute]): + Map[ExprId, AbstractBoundReference] = { + (left.zipWithIndex.map { + case (e, ordinal) => + (e.exprId, LeftBoundReference(ordinal, e.dataType, e.nullable)) + } ++ right.zipWithIndex.map { + case (e, ordinal) => + (e.exprId, RightBoundReference(ordinal, e.dataType, e.nullable)) + }).toMap + } + + def bindJoinReferences( + expressions: Seq[Expression], + left: Seq[Attribute], + right: Seq[Attribute]): Seq[Expression] = { + val refMap = createJoinReferenceMap(left, right) + expressions.map { expression => + expression.transform { case a: AttributeReference => + attachTree(a, "Binding attribute") { + refMap.getOrElse(a.exprId, sys.error(s"Couldn't find $a in left " + + s"${left.mkString("[", ",", "]")} or right ${right.mkString("[", ",", "]")}")) + } + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 20a0e168f8be..40499c64105f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -788,7 +788,7 @@ private[execution] object AggregateProcessor { val ref = distinctExpressionSchemaMap(agg.children) evaluateExpressions += ReduceSetAlgebraic(ref, agg) case (agg: AlgebraicAggregate, false, _) => - agg.bufferOffset = bufferSchema.size + agg.mutableBufferOffset = bufferSchema.size bufferSchema ++= agg.bufferAttributes initialValues ++= agg.initialValues updateExpressions ++= agg.updateExpressions @@ -800,7 +800,7 @@ private[execution] object AggregateProcessor { val boundAgg = BindReferences.bindReference(agg, inputSchema) aggregates2 += boundAgg aggregates2OutputOffsets += i - agg.bufferOffset = bufferSchema.size + agg.mutableBufferOffset = bufferSchema.size bufferSchema ++= boundAgg.bufferAttributes val nops = Seq.fill(boundAgg.bufferAttributes.size)(NoOp) initialValues ++= nops @@ -821,9 +821,13 @@ private[execution] object AggregateProcessor { // Create the projections. 
val initialProjection = newMutableProjection(initialValues, Nil)() - val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)() val evaluateProjection = newMutableProjection(evaluateExpressions, bufferSchema)() + // (EXPERI)-MENTAL + val boundUpdateExpressions = BindReferences.bindJoinReferences( + updateExpressions, bufferSchema, inputSchema) + val updateProjection = newMutableProjection(boundUpdateExpressions, Nil)() + // Create the processor new AggregateProcessor(bufferSchema.toArray, initialProjection, updateProjection, evaluateProjection, aggregates2.toArray, aggregates2OutputOffsets.toArray, @@ -847,13 +851,13 @@ private[execution] final class AggregateProcessor( private[this] val aggregates1OutputOffsets: Array[Int]) { private[this] val join = new JoinedRow - private[this] val bufferSchemaSize = bufferSchema.length + private[this] val bufferDataTypes = bufferSchema.toSeq.map(_.dataType) private[this] val aggregates2Size = aggregates2.length private[this] val aggregates1Size = aggregates1.length // Create the initial state def initialize: MutableRow = { - val buffer = new GenericMutableRow(bufferSchemaSize) + val buffer = new SpecificMutableRow(bufferDataTypes) initialProjection.target(buffer)(EmptyRow) var i = 0 while (i < aggregates2Size) { From 25c6f42e405b696a25d6691d07a54da976ac83bc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 31 Jul 2015 09:39:32 -0400 Subject: [PATCH 11/19] Rebase. Performance improvements... --- .../spark/sql/execution/SparkStrategies.scala | 1 - .../apache/spark/sql/execution/Window.scala | 150 +++++++++++++++++- 2 files changed, 143 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8da3c35c3298..517e82590ed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -344,7 +344,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } - case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => val convertedWindowExpressions = windowExprs.map { e => val converted = e.transformDown(Utils.convertAggregateExpressions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 40499c64105f..6b015561de77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -22,8 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.{DataType, NullType, IntegerType} +import org.apache.spark.sql.types.{Decimal, DataType, NullType, IntegerType} import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer import scala.collection.mutable @@ -824,14 +825,16 @@ private[execution] object AggregateProcessor { val evaluateProjection = newMutableProjection(evaluateExpressions, bufferSchema)() // (EXPERI)-MENTAL - val boundUpdateExpressions = BindReferences.bindJoinReferences( 
-      updateExpressions, bufferSchema, inputSchema)
-    val updateProjection = newMutableProjection(boundUpdateExpressions, Nil)()
+    //val boundUpdateExpressions = BindReferences.bindJoinReferences(
+    //  updateExpressions, bufferSchema, inputSchema)
+    val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)()
+    val join = new JRow(bufferSchema.size, bufferSchema.size + inputSchema.size)
 
     // Create the processor
     new AggregateProcessor(bufferSchema.toArray, initialProjection, updateProjection,
       evaluateProjection, aggregates2.toArray, aggregates2OutputOffsets.toArray,
-      aggregates1.toArray, aggregates1BufferOffsets.toArray, aggregates1OutputOffsets.toArray)
+      aggregates1.toArray, aggregates1BufferOffsets.toArray, aggregates1OutputOffsets.toArray,
+      join)
   }
 }
 
@@ -848,9 +851,10 @@ private[execution] final class AggregateProcessor(
     private[this] val aggregates2OutputOffsets: Array[Int],
     private[this] val aggregates1: Array[AggregateExpression1],
     private[this] val aggregates1BufferOffsets: Array[Int],
-    private[this] val aggregates1OutputOffsets: Array[Int]) {
+    private[this] val aggregates1OutputOffsets: Array[Int],
+    private[this] val join: JRow) {
 
-  private[this] val join = new JoinedRow
+  //private[this] val join = new JoinedRow
   private[this] val bufferDataTypes = bufferSchema.toSeq.map(_.dataType)
   private[this] val aggregates2Size = aggregates2.length
   private[this] val aggregates1Size = aggregates1.length
@@ -922,3 +926,135 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo
   def get(i: Int, dataType: DataType): Any = delegate.get(i + offset, dataType)
   def numFields: Int = delegate.numFields - offset
 }
+
+
+final class JRow(leftNumFields: Int, totalNumFields: Int) extends InternalRow {
+/*
+  private[this] val mapping = Array.tabulate(totalNumFields) { n =>
+    if (n < leftNumFields) 0
+    else 1
+  }
+  private[this] val ordinals = Array.tabulate(totalNumFields) { n =>
+    if (n < leftNumFields) n
+    else n - leftNumFields
+  }*/
+  // Get the sign and flip the bit.
+  private[this] def row(i:Int) = (((i - leftNumFields) & -0x80000000) >>> 31) ^ 1
+
+  //
+  private[this] def ordinal(i:Int, row: Int) = i - row * leftNumFields
+
+  private[this] val rows = new Array[InternalRow](2)
+
+  /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
+    rows(0) = r1
+    rows(1) = r2
+    this
+  }
+
+  /** Updates this JoinedRow by updating its left base row. Returns itself. */
+  def withLeft(newLeft: InternalRow): InternalRow = {
+    rows(0) = newLeft
+    this
+  }
+
+  /** Updates this JoinedRow by updating its right base row. Returns itself.
*/ + def withRight(newRight: InternalRow): InternalRow = { + rows(1) = newRight + this + } + + override def toSeq: Seq[Any] = rows.flatMap(_.toSeq) + + override def numFields: Int = totalNumFields + + override def getUTF8String(i: Int): UTF8String = { + val r = row(i) + rows(r).getUTF8String(ordinal(i, r)) + } + + override def getBinary(i: Int): Array[Byte] = { + val r = row(i) + rows(r).getBinary(ordinal(i, r)) + } + + override def get(i: Int, dataType: DataType): Any = { + val r = row(i) + rows(r).get(ordinal(i, r), dataType) + } + + override def isNullAt(i: Int): Boolean = { + val r = row(i) + rows(r).isNullAt(ordinal(i, r)) + } + + override def getInt(i: Int): Int = { + val r = row(i) + rows(r).getInt(ordinal(i, r)) + } + + override def getLong(i: Int): Long = { + val r = row(i) + rows(r).getLong(ordinal(i, r)) + } + + override def getDouble(i: Int): Double = { + val r = row(i) + rows(r).getDouble(ordinal(i, r)) + } + + override def getBoolean(i: Int): Boolean = { + val r = row(i) + rows(r).getBoolean(ordinal(i, r)) + } + + override def getShort(i: Int): Short = { + val r = row(i) + rows(r).getShort(ordinal(i, r)) + } + + override def getByte(i: Int): Byte = { + val r = row(i) + rows(r).getByte(ordinal(i, r)) + } + + override def getFloat(i: Int): Float = { + val r = row(i) + rows(r).getFloat(ordinal(i, r)) + } + + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + val r = row(i) + rows(r).getDecimal(ordinal(i, r), precision, scale) + } + + override def getStruct(i: Int, numFields: Int): InternalRow = { + val r = row(i) + rows(r).getStruct(ordinal(i, r), numFields) + } + + override def copy(): InternalRow = { + val copiedValues = new Array[Any](totalNumFields) + var i = 0 + while (i < totalNumFields) { + copiedValues(i) = get(i) + i += 1 + } + new GenericInternalRow(copiedValues) + } + + override def toString: String = { + // Make sure toString never throws NullPointerException. + val Array(row1, row2) = rows + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } + } +} From 1352d26a1552bf6463683bac3e0b745fc19fcbff Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 31 Jul 2015 16:06:33 -0400 Subject: [PATCH 12/19] More performance tweaks. 
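The update projection switches back to join-bound expressions (BindReferences.bindJoinReferences) feeding a reusable JoinedRow, and JRow's row-selection helpers get documented and marked @inline. The row() helper picks a base row without branching: it subtracts the number of left fields from the ordinal and turns the sign bit of the result into an index. A minimal, dependency-free sketch of that arithmetic (plain Scala; the JoinedOrdinal object and its demo are illustrative names, not part of this patch):

object JoinedOrdinal {
  // Mirrors JRow.row: mask the sign bit of (i - numLeftFields), shift it down
  // to bit 0, and flip it. A negative difference (the ordinal lies in the left
  // row) yields 0; a non-negative one yields 1 (right row). No branches.
  def row(i: Int, numLeftFields: Int): Int =
    (((i - numLeftFields) & -0x80000000) >>> 31) ^ 1

  // Mirrors JRow.ordinal: the local ordinal inside the selected row.
  def ordinal(i: Int, row: Int, numLeftFields: Int): Int =
    i - row * numLeftFields

  def main(args: Array[String]): Unit = {
    val numLeftFields = 3
    (0 until 5).foreach { i =>
      val r = row(i, numLeftFields)
      println(s"ordinal $i -> row $r, local ordinal ${ordinal(i, r, numLeftFields)}")
    }
    // ordinals 0..2 map to row 0 (local 0..2); ordinals 3..4 to row 1 (local 0..1).
  }
}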
---
 .../apache/spark/sql/execution/Window.scala   | 49 ++++++++++---------
 1 file changed, 26 insertions(+), 23 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 6b015561de77..dfda1a291027 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -825,10 +825,12 @@ private[execution] object AggregateProcessor {
     val evaluateProjection = newMutableProjection(evaluateExpressions, bufferSchema)()
 
     // (EXPERI)-MENTAL
-    //val boundUpdateExpressions = BindReferences.bindJoinReferences(
-    //  updateExpressions, bufferSchema, inputSchema)
-    val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)()
-    val join = new JRow(bufferSchema.size, bufferSchema.size + inputSchema.size)
+    val boundUpdateExpressions = BindReferences.bindJoinReferences(
+      updateExpressions, bufferSchema, inputSchema)
+    //val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)()
+    val updateProjection = newMutableProjection(boundUpdateExpressions, Nil)()
+    val join = new JoinedRow
+    //val join = new JRow(bufferSchema.size, bufferSchema.size + inputSchema.size)
 
     // Create the processor
     new AggregateProcessor(bufferSchema.toArray, initialProjection, updateProjection,
@@ -852,7 +854,7 @@ private[execution] final class AggregateProcessor(
     private[this] val aggregates1: Array[AggregateExpression1],
     private[this] val aggregates1BufferOffsets: Array[Int],
     private[this] val aggregates1OutputOffsets: Array[Int],
-    private[this] val join: JRow) {
+    private[this] val join: JoinedRow) {
 
   //private[this] val join = new JoinedRow
   private[this] val bufferDataTypes = bufferSchema.toSeq.map(_.dataType)
@@ -928,21 +930,24 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo
 }
 
 
-final class JRow(leftNumFields: Int, totalNumFields: Int) extends InternalRow {
-/*
-  private[this] val mapping = Array.tabulate(totalNumFields) { n =>
-    if (n < leftNumFields) 0
-    else 1
-  }
-  private[this] val ordinals = Array.tabulate(totalNumFields) { n =>
-    if (n < leftNumFields) n
-    else n - leftNumFields
-  }*/
-  // Get the sign and flip the bit.
-  private[this] def row(i:Int) = (((i - leftNumFields) & -0x80000000) >>> 31) ^ 1
+final class JRow(private[this] val numLeftFields: Int, val numFields: Int) extends InternalRow {
+
+  /**
+   * Determine the index of the row which maps to the given ordinal. This method has been
+   * implemented using bitwise operations in order to prevent expensive branching.
+   *
+   * The key idea here is that the number of fields in the left row is subtracted from the given
+   * ordinal, and we then use the sign of that calculation to determine the index of the row.
+   *
+   * @param i ordinal to find the row for.
+   * @return index of the row that belongs to the given ordinal.
+   */
+  @inline
+  private[this] def row(i:Int) = (((i - numLeftFields) & -0x80000000) >>> 31) ^ 1
 
-  //
-  private[this] def ordinal(i:Int, row: Int) = i - row * leftNumFields
+  /** Determine the row ordinal given the row index and the input ordinal.
*/ + @inline + private[this] def ordinal(i:Int, row: Int) = i - row * numLeftFields private[this] val rows = new Array[InternalRow](2) @@ -967,8 +972,6 @@ final class JRow(leftNumFields: Int, totalNumFields: Int) extends InternalRow { override def toSeq: Seq[Any] = rows.flatMap(_.toSeq) - override def numFields: Int = totalNumFields - override def getUTF8String(i: Int): UTF8String = { val r = row(i) rows(r).getUTF8String(ordinal(i, r)) @@ -1035,9 +1038,9 @@ final class JRow(leftNumFields: Int, totalNumFields: Int) extends InternalRow { } override def copy(): InternalRow = { - val copiedValues = new Array[Any](totalNumFields) + val copiedValues = new Array[Any](numFields) var i = 0 - while (i < totalNumFields) { + while (i < numFields) { copiedValues(i) = get(i) i += 1 } From d91bcae21fe7ceb197691d2f9e975d4920307f28 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 3 Aug 2015 08:31:15 -0400 Subject: [PATCH 13/19] Further tweaks --- .../catalyst/expressions/BoundAttribute.scala | 11 +- .../expressions/codegen/CodeGenerator.scala | 16 ++ .../codegen/GenerateMutableProjection.scala | 1 + .../codegen/GenerateProjection.scala | 1 + .../codegen/GenerateUnsafeProjection.scala | 149 ++++++++++++++++++ 5 files changed, 174 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index e8fe09fe3963..401914219c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -35,6 +35,8 @@ abstract class AbstractBoundReference extends LeafExpression with NamedExpressio protected[this] def genCodeInput = "i" + protected[this] def join = false + protected[this] def unwrap(input: InternalRow): InternalRow = input override def toString: String = s"${prefix}input[$ordinal, $dataType]" @@ -74,6 +76,7 @@ abstract class AbstractBoundReference extends LeafExpression with NamedExpressio override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.join = join val javaType = ctx.javaType(dataType) val value = ctx.getValue(genCodeInput, dataType, ordinal.toString) s""" @@ -88,18 +91,18 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case class LeftBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends AbstractBoundReference { + override protected def join = true override protected def prefix = "left" - override protected def genCodeInput = - "((org.apache.spark.sql.catalyst.expressions.JoinedRow)i).left()" + override protected def genCodeInput = "left" override protected def unwrap(input: InternalRow): InternalRow = input.asInstanceOf[JoinedRow].left } case class RightBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends AbstractBoundReference { + override protected def join = true override protected def prefix = "right" - override protected def genCodeInput = - "((org.apache.spark.sql.catalyst.expressions.JoinedRow)i).right()" + override protected def genCodeInput = "right" override protected def unwrap(input: InternalRow): InternalRow = input.asInstanceOf[JoinedRow].right } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8..1c9d43d07552 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -90,6 +90,22 @@ class CodeGenContext { addedFuntions += ((funcName, funcCode)) } + /** + * Set to true when we are expecting a JoinedRow as input. + */ + var join: Boolean = false + def unwrapJoinRow: String = { + if (join) { + // We could also add the left()/right() methods to the InternalRow interface. + """ + |org.apache.spark.sql.catalyst.expressions.JoinedRow join = (org.apache.spark.sql.catalyst.expressions.JoinedRow) i; + |InternalRow left = join.left(); + |InternalRow right = join.right(); + """.stripMargin + } + else "" + } + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ac58423cd884..dc49694594e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -127,6 +127,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public Object apply(Object _i) { InternalRow i = (InternalRow) _i; + ${ctx.unwrapJoinRow} $projectionCalls return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e..18951449b401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -179,6 +179,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $columns public SpecificRow(InternalRow i) { + ${ctx.unwrapJoinRow} $initColumns } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d8912df694a1..55377abc1560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -116,6 +116,155 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") } + /** + * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
+ * @param ctx context for code generation + * @param ev specifies the name of the variable for the output [[UnsafeRow]] object + * @param expressions input expressions + * @return generated code to put the expression output into an [[UnsafeRow]] + */ + def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) + : String = { + + val ret = ev.primitive + ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + val numBytes = ctx.freshName("numBytes") + + val exprs = expressions.map { e => e.dataType match { + case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) + case _ => e.gen(ctx) + }} + val allExprs = exprs.map(_.code).mkString("\n") + + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) + }.mkString("") + + val writers = expressions.zipWithIndex.map { case (e, i) => + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) + s"""if (${exprs(i).isNull}) { + $ret.setNullAt($i); + } else { + $update; + }""" + }.mkString("\n ") + + s""" + ${ctx.unwrapJoinRow} + $allExprs + int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + $ret.pointTo( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, + $numBytes); + int $cursor = $fixedSize; + + $writers + boolean ${ev.isNull} = false; + """ + } + + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * This function also handles nested structs by recursively generating the code to do conversion. + * + * @param ctx code generation context + * @param input the input struct, identified by a [[GeneratedExpressionCode]] + * @param schema schema of the struct field + */ + // TODO: refactor createCode and this function to reduce code duplication. 
+  private def createCodeForStruct(
+      ctx: CodeGenContext,
+      input: GeneratedExpressionCode,
+      schema: StructType): GeneratedExpressionCode = {
+
+    val isNull = input.isNull
+    val primitive = ctx.freshName("structConvert")
+    ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
+    val buffer = ctx.freshName("buffer")
+    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+    val cursor = ctx.freshName("cursor")
+
+    val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
+      case (dt, i) => dt match {
+        case st: StructType =>
+          val nestedStructEv = GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
+          )
+          createCodeForStruct(ctx, nestedStructEv, st)
+        case _ =>
+          GeneratedExpressionCode(
+            code = "",
+            isNull = s"${input.primitive}.isNullAt($i)",
+            primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
+          )
+      }
+    }
+    val allExprs = exprs.map(_.code).mkString("\n")
+
+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+    val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
+      genAdditionalSize(dt, ev)
+    }.mkString("")
+
+    val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
+      val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
+      s"""
+        if (${exprs(i).isNull}) {
+          $primitive.setNullAt($i);
+        } else {
+          $update;
+        }
+      """
+    }.mkString("\n ")
+
+    // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
+    // just copy the bytes directly into our buffer space without running any conversion.
+    // We also had to use a hack to introduce a "tmp" variable, to keep the Java compiler from
+    // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
+    val tmp = ctx.freshName("tmp")
+    val numBytes = ctx.freshName("numBytes")
+    val code = s"""
+      |${input.code}
+      |if (!${input.isNull}) {
+      |  Object $tmp = (Object) ${input.primitive};
+      |  if ($tmp instanceof UnsafeRow) {
+      |    $primitive = (UnsafeRow) $tmp;
+      |  } else {
+      |    $allExprs
+      |
+      |    int $numBytes = $fixedSize $additionalSize;
+      |    if ($numBytes > $buffer.length) {
+      |      $buffer = new byte[$numBytes];
+      |    }
+      |
+      |    $primitive.pointTo(
+      |      $buffer,
+      |      $PlatformDependent.BYTE_ARRAY_OFFSET,
+      |      ${exprs.size},
+      |      $numBytes);
+      |    int $cursor = $fixedSize;
+      |
+      |    $writers
+      |  }
+      |}
+    """.stripMargin
+
+    GeneratedExpressionCode(code, isNull, primitive)
+  }
+
   /**
    * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
    *

From e792325b636dbcfec0725636a68030db0f95b2fa Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Mon, 3 Aug 2015 11:43:36 -0400
Subject: [PATCH 14/19] Rebase...
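Besides the rebase (assignments to mutableBufferOffset become withNewMutableBufferOffset calls, and typed get(i, dataType) accesses become genericGet), JRow.copy now returns a new JRow wrapping copies of both base rows, and OffsetMutableRow gains a copy of its own. OffsetMutableRow simply shifts every ordinal by a fixed offset so a window frame can write its output into its own slice of the shared result row. A toy, array-backed sketch of that ordinal shift (OffsetView is an illustrative stand-in, not the Spark MutableRow API):

// Writes to slot i of the view land in slot (i + offset) of the backing
// array, the same shift OffsetMutableRow applies before delegating.
final class OffsetView(offset: Int, backing: Array[Any]) {
  def update(i: Int, value: Any): Unit = backing(i + offset) = value
  def apply(i: Int): Any = backing(i + offset)
  def numFields: Int = backing.length - offset
}

object OffsetViewDemo extends App {
  val result = new Array[Any](5)
  val view = new OffsetView(2, result) // a "frame" owning ordinals 2..4
  view(0) = "rank"                     // actually writes result(2)
  println(result.mkString("[", ", ", "]")) // [null, null, rank, null, null]
}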
--- .../catalyst/expressions/aggregate/sets.scala | 2 +- .../apache/spark/sql/execution/Window.scala | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala index d70a75b0b5f7..e78f1a8f753e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/sets.scala @@ -79,7 +79,7 @@ case class ReduceSetAlgebraic(left: Expression, right: AlgebraicAggregate) case class ReduceSetAggregate(left: Expression, right: AggregateFunction2) extends BinaryExpression with CodegenFallback { - right.mutableBufferOffset = 0 + right.withNewMutableBufferOffset(0) override def dataType: DataType = right.dataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index dfda1a291027..3e8ac83df8ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -789,7 +789,7 @@ private[execution] object AggregateProcessor { val ref = distinctExpressionSchemaMap(agg.children) evaluateExpressions += ReduceSetAlgebraic(ref, agg) case (agg: AlgebraicAggregate, false, _) => - agg.mutableBufferOffset = bufferSchema.size + agg.withNewMutableBufferOffset(bufferSchema.size) bufferSchema ++= agg.bufferAttributes initialValues ++= agg.initialValues updateExpressions ++= agg.updateExpressions @@ -801,7 +801,7 @@ private[execution] object AggregateProcessor { val boundAgg = BindReferences.bindReference(agg, inputSchema) aggregates2 += boundAgg aggregates2OutputOffsets += i - agg.mutableBufferOffset = bufferSchema.size + agg.withNewMutableBufferOffset(bufferSchema.size) bufferSchema ++= boundAgg.bufferAttributes val nops = Seq.fill(boundAgg.bufferAttributes.size)(NoOp) initialValues ++= nops @@ -888,7 +888,7 @@ private[execution] final class AggregateProcessor( } i = 0 while (i < aggregates1Size) { - buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i), null).update(input) + buffer.get(aggregates1BufferOffsets(i), null).asInstanceOf[AggregateFunction1].update(input) i += 1 } } @@ -913,7 +913,7 @@ private[execution] final class AggregateProcessor( } i = 0 while (i < aggregates1Size) { - val function = buffer.getAs[AggregateFunction1](aggregates1BufferOffsets(i), null) + val function = buffer.get(aggregates1BufferOffsets(i), null).asInstanceOf[AggregateFunction1] val value = function.eval(EmptyRow) target.update(aggregates1OutputOffsets(i), value) i += 1 @@ -925,8 +925,18 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo extends MutableRow { def setNullAt(i: Int): Unit = delegate.setNullAt(i + offset) def update(i: Int, value: Any): Unit = delegate.update(i + offset, value) - def get(i: Int, dataType: DataType): Any = delegate.get(i + offset, dataType) + def genericGet(i: Int): Any = delegate.genericGet(i + offset) def numFields: Int = delegate.numFields - offset + def copy(): InternalRow = { + val numFields = delegate.numFields + val values = new Array[Any](numFields) + var i = 0 + while (i < numFields) { + values(i) = delegate.genericGet(i) + i += 1 + } + new OffsetMutableRow(offset, new GenericMutableRow(values)) + } } @@ -982,9 +992,9 @@ final class 
JRow(private[this] val numLeftFields: Int, val numFields: Int) exten rows(r).getBinary(ordinal(i, r)) } - override def get(i: Int, dataType: DataType): Any = { + override def genericGet(i: Int): Any = { val r = row(i) - rows(r).get(ordinal(i, r), dataType) + rows(r).genericGet(ordinal(i, r)) } override def isNullAt(i: Int): Boolean = { @@ -1038,13 +1048,9 @@ final class JRow(private[this] val numLeftFields: Int, val numFields: Int) exten } override def copy(): InternalRow = { - val copiedValues = new Array[Any](numFields) - var i = 0 - while (i < numFields) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) + val row = new JRow(numLeftFields, numFields) + row(rows(0).copy(), rows(1).copy()) + row } override def toString: String = { From 47b09239be23e4f8d523c6766f94a85d9646bebd Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 5 Aug 2015 23:00:35 -0400 Subject: [PATCH 15/19] Revert performance improvements (moved them to SPARK-9357). --- .../catalyst/expressions/BoundAttribute.scala | 75 ++------- .../apache/spark/sql/execution/Window.scala | 146 +----------------- 2 files changed, 13 insertions(+), 208 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 401914219c11..61e23ab8019e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,22 +28,13 @@ import org.apache.spark.sql.types._ * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ -abstract class AbstractBoundReference extends LeafExpression with NamedExpression { - val ordinal: Int - - protected[this] def prefix: String = "" - - protected[this] def genCodeInput = "i" - - protected[this] def join = false - - protected[this] def unwrap(input: InternalRow): InternalRow = input +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends LeafExpression with NamedExpression { - override def toString: String = s"${prefix}input[$ordinal, $dataType]" + override def toString: String = s"input[$ordinal, $dataType]" // Use special getter for primitive types (for UnsafeRow) - override def eval(i: InternalRow): Any = { - val input = unwrap(i) + override def eval(input: InternalRow): Any = { if (input.isNullAt(ordinal)) { null } else { @@ -76,43 +67,21 @@ abstract class AbstractBoundReference extends LeafExpression with NamedExpressio override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - ctx.join = join val javaType = ctx.javaType(dataType) - val value = ctx.getValue(genCodeInput, dataType, ordinal.toString) + val value = ctx.getValue("i", dataType, ordinal.toString) s""" - boolean ${ev.isNull} = $genCodeInput.isNullAt($ordinal); + boolean ${ev.isNull} = i.isNullAt($ordinal); $javaType ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value); """ } } -case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends AbstractBoundReference - -case class LeftBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends AbstractBoundReference { - override protected def join = true - override protected def prefix = "left" - override protected def genCodeInput = "left" - override protected def unwrap(input: InternalRow): InternalRow = - input.asInstanceOf[JoinedRow].left -} - -case class RightBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends AbstractBoundReference { - override protected def join = true - override protected def prefix = "right" - override protected def genCodeInput = "right" - override protected def unwrap(input: InternalRow): InternalRow = - input.asInstanceOf[JoinedRow].right -} - object BindReferences extends Logging { def bindReference[A <: Expression]( - expression: A, - input: Seq[Attribute], - allowFailures: Boolean = false): A = { + expression: A, + input: Seq[Attribute], + allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) @@ -128,30 +97,4 @@ object BindReferences extends Logging { } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. } - - def createJoinReferenceMap(left: Seq[Attribute], right: Seq[Attribute]): - Map[ExprId, AbstractBoundReference] = { - (left.zipWithIndex.map { - case (e, ordinal) => - (e.exprId, LeftBoundReference(ordinal, e.dataType, e.nullable)) - } ++ right.zipWithIndex.map { - case (e, ordinal) => - (e.exprId, RightBoundReference(ordinal, e.dataType, e.nullable)) - }).toMap - } - - def bindJoinReferences( - expressions: Seq[Expression], - left: Seq[Attribute], - right: Seq[Attribute]): Seq[Expression] = { - val refMap = createJoinReferenceMap(left, right) - expressions.map { expression => - expression.transform { case a: AttributeReference => - attachTree(a, "Binding attribute") { - refMap.getOrElse(a.exprId, sys.error(s"Couldn't find $a in left " + - s"${left.mkString("[", ",", "]")} or right ${right.mkString("[", ",", "]")}")) - } - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 3e8ac83df8ed..4871d07035e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -822,21 +822,13 @@ private[execution] object AggregateProcessor { // Create the projections. 
     val initialProjection = newMutableProjection(initialValues, Nil)()
+    val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)()
     val evaluateProjection = newMutableProjection(evaluateExpressions, bufferSchema)()
 
-    // (EXPERI)-MENTAL
-    val boundUpdateExpressions = BindReferences.bindJoinReferences(
-      updateExpressions, bufferSchema, inputSchema)
-    //val updateProjection = newMutableProjection(updateExpressions, bufferSchema ++ inputSchema)()
-    val updateProjection = newMutableProjection(boundUpdateExpressions, Nil)()
-    val join = new JoinedRow
-    //val join = new JRow(bufferSchema.size, bufferSchema.size + inputSchema.size)
-
     // Create the processor
     new AggregateProcessor(bufferSchema.toArray, initialProjection, updateProjection,
       evaluateProjection, aggregates2.toArray, aggregates2OutputOffsets.toArray,
-      aggregates1.toArray, aggregates1BufferOffsets.toArray, aggregates1OutputOffsets.toArray,
-      join)
+      aggregates1.toArray, aggregates1BufferOffsets.toArray, aggregates1OutputOffsets.toArray)
   }
 }
 
@@ -853,10 +845,9 @@ private[execution] final class AggregateProcessor(
     private[this] val aggregates2OutputOffsets: Array[Int],
     private[this] val aggregates1: Array[AggregateExpression1],
     private[this] val aggregates1BufferOffsets: Array[Int],
-    private[this] val aggregates1OutputOffsets: Array[Int],
-    private[this] val join: JoinedRow) {
+    private[this] val aggregates1OutputOffsets: Array[Int]) {
 
-  //private[this] val join = new JoinedRow
+  private[this] val join = new JoinedRow
   private[this] val bufferDataTypes = bufferSchema.toSeq.map(_.dataType)
   private[this] val aggregates2Size = aggregates2.length
   private[this] val aggregates1Size = aggregates1.length
@@ -938,132 +929,3 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo
     new OffsetMutableRow(offset, new GenericMutableRow(values))
   }
 }
-
-
-final class JRow(private[this] val numLeftFields: Int, val numFields: Int) extends InternalRow {
-
-  /**
-   * Determine the index of the row which maps to the given ordinal. This method has been
-   * implemented using bitwise operations in order to prevent expensive branching.
-   *
-   * The key idea here is that the number of fields in the left row is subtracted from the given
-   * ordinal, and we then use the sign of that calculation to determine the index of the row.
-   *
-   * @param i ordinal to find the row for.
-   * @return index of the row that belongs to the given ordinal.
-   */
-  @inline
-  private[this] def row(i:Int) = (((i - numLeftFields) & -0x80000000) >>> 31) ^ 1
-
-  /** Determine the row ordinal given the row index and the input ordinal. */
-  @inline
-  private[this] def ordinal(i:Int, row: Int) = i - row * numLeftFields
-
-  private[this] val rows = new Array[InternalRow](2)
-
-  /** Updates this JoinedRow to point at two new base rows. Returns itself. */
-  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
-    rows(0) = r1
-    rows(1) = r2
-    this
-  }
-
-  /** Updates this JoinedRow by updating its left base row. Returns itself. */
-  def withLeft(newLeft: InternalRow): InternalRow = {
-    rows(0) = newLeft
-    this
-  }
-
-  /** Updates this JoinedRow by updating its right base row. Returns itself. */
-  def withRight(newRight: InternalRow): InternalRow = {
-    rows(1) = newRight
-    this
-  }
-
-  override def toSeq: Seq[Any] = rows.flatMap(_.toSeq)
-
-  override def getUTF8String(i: Int): UTF8String = {
-    val r = row(i)
-    rows(r).getUTF8String(ordinal(i, r))
-  }
-
-  override def getBinary(i: Int): Array[Byte] = {
-    val r = row(i)
-    rows(r).getBinary(ordinal(i, r))
-  }
-
-  override def genericGet(i: Int): Any = {
-    val r = row(i)
-    rows(r).genericGet(ordinal(i, r))
-  }
-
-  override def isNullAt(i: Int): Boolean = {
-    val r = row(i)
-    rows(r).isNullAt(ordinal(i, r))
-  }
-
-  override def getInt(i: Int): Int = {
-    val r = row(i)
-    rows(r).getInt(ordinal(i, r))
-  }
-
-  override def getLong(i: Int): Long = {
-    val r = row(i)
-    rows(r).getLong(ordinal(i, r))
-  }
-
-  override def getDouble(i: Int): Double = {
-    val r = row(i)
-    rows(r).getDouble(ordinal(i, r))
-  }
-
-  override def getBoolean(i: Int): Boolean = {
-    val r = row(i)
-    rows(r).getBoolean(ordinal(i, r))
-  }
-
-  override def getShort(i: Int): Short = {
-    val r = row(i)
-    rows(r).getShort(ordinal(i, r))
-  }
-
-  override def getByte(i: Int): Byte = {
-    val r = row(i)
-    rows(r).getByte(ordinal(i, r))
-  }
-
-  override def getFloat(i: Int): Float = {
-    val r = row(i)
-    rows(r).getFloat(ordinal(i, r))
-  }
-
-  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
-    val r = row(i)
-    rows(r).getDecimal(ordinal(i, r), precision, scale)
-  }
-
-  override def getStruct(i: Int, numFields: Int): InternalRow = {
-    val r = row(i)
-    rows(r).getStruct(ordinal(i, r), numFields)
-  }
-
-  override def copy(): InternalRow = {
-    val row = new JRow(numLeftFields, numFields)
-    row(rows(0).copy(), rows(1).copy())
-    row
-  }
-
-  override def toString: String = {
-    // Make sure toString never throws NullPointerException.
-    val Array(row1, row2) = rows
-    if ((row1 eq null) && (row2 eq null)) {
-      "[ empty row ]"
-    } else if (row1 eq null) {
-      row2.mkString("[", ",", "]")
-    } else if (row2 eq null) {
-      row1.mkString("[", ",", "]")
-    } else {
-      mkString("[", ",", "]")
-    }
-  }
-}

From e3091c6e691b9a1f433d4073aee2642ccd469201 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 5 Aug 2015 23:29:23 -0400
Subject: [PATCH 16/19] Rebase/Revert Performance Tweaks

---
 .../sql/catalyst/analysis/Analyzer.scala      | 59 +------------------
 .../apache/spark/sql/execution/Window.scala   | 21 ++++---
 2 files changed, 11 insertions(+), 69 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 46abf8d0e1b9..f5246acece99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -799,7 +799,7 @@ class Analyzer(
       // Extract Windowed AggregateExpression
       case we @ WindowExpression(agg: AggregateExpression, spec: WindowSpecDefinition) =>
         val newAggChildren = agg.children.map(extractExpr)
-        val newAgg = agg.withNewChildren(newAggChildren)
+        val newAgg = agg.withNewChildren(newAggChildren).asInstanceOf[AggregateExpression]
         seenWindowAggregates += newAgg
         WindowExpression(newAgg, spec)
 
@@ -981,63 +981,6 @@ class Analyzer(
     }
   }
 
-  /**
-   * Removes all still-need-evaluate ordering expressions from sort and uses an inner project to
-   * materialize them, finally using an outer project to project them away to keep the result the
-   * same. Then we can make sure we only sort by [[AttributeReference]]s.
- * - * As an example, - * {{{ - * Sort('a, 'b + 1, - * Relation('a, 'b)) - * }}} - * will be turned into: - * {{{ - * Project('a, 'b, - * Sort('a, '_sortCondition, - * Project('a, 'b, ('b + 1).as("_sortCondition"), - * Relation('a, 'b)))) - * }}} - */ - object RemoveEvaluationFromSort extends Rule[LogicalPlan] { - private def hasAlias(expr: Expression) = { - expr.find { - case a: Alias => true - case _ => false - }.isDefined - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // The ordering expressions have no effect to the output schema of `Sort`, - // so `Alias`s in ordering expressions are unnecessary and we should remove them. - case s@Sort(ordering, _, _) if ordering.exists(hasAlias) => - val newOrdering = ordering.map(_.transformUp { - case Alias(child, _) => child - }.asInstanceOf[SortOrder]) - s.copy(order = newOrdering) - - case s@Sort(ordering, global, child) - if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => - - val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) - - val namedExpr = needEval.map(_.child match { - case n: NamedExpression => n - case e => Alias(e, "_sortCondition")() - }) - - val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => - order.copy(child = ne.toAttribute) - } - - // Add still-need-evaluate ordering expressions into inner project and then project - // them away after the sort. - Project(child.output, - Sort(newOrdering, global, - Project(child.output ++ namedExpr, child))) - } - } - /** * Check and add proper window frames for all window functions. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 4871d07035e0..e80ea2e600a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -916,16 +916,15 @@ private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRo extends MutableRow { def setNullAt(i: Int): Unit = delegate.setNullAt(i + offset) def update(i: Int, value: Any): Unit = delegate.update(i + offset, value) - def genericGet(i: Int): Any = delegate.genericGet(i + offset) + override def setBoolean(i: Int, value: Boolean): Unit = delegate.setBoolean(i, value) + override def setByte(i: Int, value: Byte): Unit = delegate.setByte(i, value) + override def setShort(i: Int, value: Short): Unit = delegate.setShort(i, value) + override def setInt(i: Int, value: Int): Unit = delegate.setInt(i, value) + override def setLong(i: Int, value: Long): Unit = delegate.setLong(i, value) + override def setFloat(i: Int, value: Float): Unit = delegate.setFloat(i, value) + override def setDouble(i: Int, value: Double): Unit = delegate.setDouble(i, value) + override def setDecimal(i: Int, value: Decimal, precision: Int): Unit = + delegate.setDecimal(i, value, precision) def numFields: Int = delegate.numFields - offset - def copy(): InternalRow = { - val numFields = delegate.numFields - val values = new Array[Any](numFields) - var i = 0 - while (i < numFields) { - values(i) = delegate.genericGet(i) - i += 1 - } - new OffsetMutableRow(offset, new GenericMutableRow(values)) - } + def copy(): InternalRow = this } From 6836e4b9f8c15b207a185fd5ca08483c97b2dd6c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 5 Aug 2015 23:33:53 -0400 Subject: [PATCH 17/19] Revert change to BoundAttribute. 
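Only the parameter indentation in BindReferences.bindReference changes here; the behaviour stays exactly as before: each AttributeReference is replaced by a BoundReference carrying the ordinal at which its exprId appears in the input schema. A simplified sketch of that lookup (toy Attr/Bound case classes standing in for the Catalyst types):

// Toy version of the exprId-to-ordinal lookup performed by bindReference.
case class Attr(exprId: Long, name: String)
case class Bound(ordinal: Int)

object BindSketch {
  def bind(a: Attr, input: Seq[Attr]): Bound = {
    val ordinal = input.indexWhere(_.exprId == a.exprId)
    if (ordinal == -1) {
      sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
    }
    Bound(ordinal)
  }

  def main(args: Array[String]): Unit = {
    val schema = Seq(Attr(10L, "a"), Attr(11L, "b"))
    println(bind(Attr(11L, "b"), schema)) // Bound(1)
  }
}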
--- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 61e23ab8019e..473b9b787058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -79,9 +79,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) object BindReferences extends Logging { def bindReference[A <: Expression]( - expression: A, - input: Seq[Attribute], - allowFailures: Boolean = false): A = { + expression: A, + input: Seq[Attribute], + allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) From 7deee3857127b4754ffdf00439a8fbc14059db00 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 10 Aug 2015 18:19:04 -0400 Subject: [PATCH 18/19] Rebase. Removed OffsetRow. Replaced CompactBuffer with ArrayBuffer. --- .../apache/spark/sql/execution/Window.scala | 94 ++++++++----------- 1 file changed, 40 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index e80ea2e600a5..2dda69f84121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -22,11 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.{Decimal, DataType, NullType, IntegerType} +import org.apache.spark.sql.types._ import org.apache.spark.rdd.RDD -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.collection.CompactBuffer -import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap} /** * :: DeveloperApi :: @@ -171,11 +169,8 @@ case class Window( frame: (Char, WindowFrame), functions: Array[Expression], ordinal: Int, - result: MutableRow, + target: MutableRow, size: MutableLiteral): WindowFunctionFrame = { - // Construct the target row. - val target = if (ordinal == 0) result - else new OffsetMutableRow(ordinal, result) // Construct an aggregate processor if we have to. def processor = { @@ -183,7 +178,7 @@ case class Window( case f: SizeBasedWindowFunction => f.withSize(size) case f => f } - AggregateProcessor(prepared, child.output, newMutableProjection) + AggregateProcessor(prepared, ordinal, child.output, newMutableProjection) } // Create the frame processor. @@ -193,7 +188,13 @@ case class Window( FrameBoundaryExtractor(l), FrameBoundaryExtractor(h))) if l == h => - new OffsetWindowFunctionFrame(target, functions, child.output, newMutableProjection, l) + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + newMutableProjection, + l) // Growing Frame. case ('A', SpecifiedWindowFrame(frameType, @@ -270,7 +271,7 @@ case class Window( // function result buffer. 
val factories = Array.ofDim[(MutableRow, MutableLiteral) => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = mutable.Buffer.empty[Expression] + val unboundExpressions = Buffer.empty[Expression] framedWindowExprs.zipWithIndex.foreach { case ((frame, unboundFrameExpressions), index) => // Track the ordinal. @@ -312,7 +313,7 @@ case class Window( fetchNextRow() // Manage the current partition. - var rows: CompactBuffer[InternalRow] = _ + val rows = ArrayBuffer.empty[InternalRow] val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) val partitionSize = MutableLiteral(0, IntegerType, nullable = false) val frames: Array[WindowFunctionFrame] = factories.map{ f => @@ -322,7 +323,7 @@ case class Window( private[this] def fetchNextPartition() { // Collect all the rows in the current partition. val currentGroup = nextGroup - rows = new CompactBuffer + rows.clear() while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() fetchNextRow() @@ -413,7 +414,7 @@ private[execution] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: CompactBuffer[InternalRow]): Unit + def prepare(rows: ArrayBuffer[InternalRow]): Unit /** * Write the current results to the target row. @@ -432,13 +433,14 @@ private[execution] abstract class WindowFunctionFrame { */ private[execution] final class OffsetWindowFunctionFrame( target: MutableRow, + ordinal: Int, expressions: Array[Expression], inputSchema: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 @@ -458,7 +460,7 @@ private[execution] final class OffsetWindowFunctionFrame( val defaultInputSchema = inputSchema.map(_.newInstance()) ++ inputSchema // Collect the expressions and bind them. - val boundExpressions = expressions.toSeq.map { + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { case e: OffsetWindowFunction => val input = BindReferences.bindReference(e.input, inputSchema) if (e.default == null || e.default.foldable && e.default.eval() == null) { @@ -477,7 +479,7 @@ private[execution] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil)().target(target) } - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = offset outputIndex = 0 @@ -512,7 +514,7 @@ private[execution] final class SlidingWindowFunctionFrame( ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -526,7 +528,7 @@ private[execution] final class SlidingWindowFunctionFrame( private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. 
*/ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputHighIndex = 0 inputLowIndex = 0 @@ -584,7 +586,7 @@ private[execution] final class UnboundedWindowFunctionFrame( private[this] var status: MutableRow = _ /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { status = processor.initialize processor.update(status, rows, 0, rows.size) } @@ -616,7 +618,7 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -629,7 +631,7 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( private[this] var status: MutableRow = _ /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 @@ -680,7 +682,7 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value equal to or greater than the lower bound of the * current output row. */ @@ -690,7 +692,7 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. 
*/ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 @@ -744,18 +746,19 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( */ private[execution] object AggregateProcessor { def apply(functions: Array[Expression], + ordinal: Int, inputSchema: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): AggregateProcessor = { - val bufferSchema = mutable.Buffer.empty[AttributeReference] - val initialValues = mutable.Buffer.empty[Expression] - val updateExpressions = mutable.Buffer.empty[Expression] - val evaluateExpressions = mutable.Buffer.empty[Expression] - val aggregates1 = mutable.Buffer.empty[AggregateExpression1] - val aggregates1BufferOffsets = mutable.Buffer.empty[Int] - val aggregates1OutputOffsets = mutable.Buffer.empty[Int] - val aggregates2 = mutable.Buffer.empty[AggregateFunction2] - val aggregates2OutputOffsets = mutable.Buffer.empty[Int] + val bufferSchema = Buffer.empty[AttributeReference] + val initialValues = Buffer.empty[Expression] + val updateExpressions = Buffer.empty[Expression] + val evaluateExpressions = Buffer.fill[Expression](ordinal)(NoOp) + val aggregates1 = Buffer.empty[AggregateExpression1] + val aggregates1BufferOffsets = Buffer.empty[Int] + val aggregates1OutputOffsets = Buffer.empty[Int] + val aggregates2 = Buffer.empty[AggregateFunction2] + val aggregates2OutputOffsets = Buffer.empty[Int] // Flatten AggregateExpression2's val flattened = functions.zipWithIndex.map { @@ -764,7 +767,7 @@ private[execution] object AggregateProcessor { } // Add distinct evaluation path. - val distinctExpressionSchemaMap = mutable.HashMap.empty[Seq[Expression], AttributeReference] + val distinctExpressionSchemaMap = HashMap.empty[Seq[Expression], AttributeReference] flattened.filter(_._2).foreach { case (af2, _, _) => // TODO cannocalize expressions? @@ -800,7 +803,7 @@ private[execution] object AggregateProcessor { case (agg: AggregateFunction2, false, i) => val boundAgg = BindReferences.bindReference(agg, inputSchema) aggregates2 += boundAgg - aggregates2OutputOffsets += i + aggregates2OutputOffsets += (i + ordinal) agg.withNewMutableBufferOffset(bufferSchema.size) bufferSchema ++= boundAgg.bufferAttributes val nops = Seq.fill(boundAgg.bufferAttributes.size)(NoOp) @@ -810,7 +813,7 @@ private[execution] object AggregateProcessor { case (agg: AggregateExpression1, false, i) => aggregates1 += BindReferences.bindReference(agg, inputSchema) aggregates1BufferOffsets += bufferSchema.size - aggregates1OutputOffsets += i + aggregates1OutputOffsets += (i + ordinal) // TODO typing - we would need to create UDT for this. bufferSchema += AttributeReference("agg", NullType, nullable = false)() initialValues += NoOp @@ -885,7 +888,7 @@ private[execution] final class AggregateProcessor( } /** Bulk update the given buffer. 
*/ - def update(buffer: MutableRow, input: CompactBuffer[InternalRow], begin: Int, end: Int): Unit = { + def update(buffer: MutableRow, input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = { var i = begin while (i < end) { update(buffer, input(i)) @@ -911,20 +914,3 @@ private[execution] final class AggregateProcessor( } } } - -private[execution] final class OffsetMutableRow(offset: Int, delegate: MutableRow) - extends MutableRow { - def setNullAt(i: Int): Unit = delegate.setNullAt(i + offset) - def update(i: Int, value: Any): Unit = delegate.update(i + offset, value) - override def setBoolean(i: Int, value: Boolean): Unit = delegate.setBoolean(i, value) - override def setByte(i: Int, value: Byte): Unit = delegate.setByte(i, value) - override def setShort(i: Int, value: Short): Unit = delegate.setShort(i, value) - override def setInt(i: Int, value: Int): Unit = delegate.setInt(i, value) - override def setLong(i: Int, value: Long): Unit = delegate.setLong(i, value) - override def setFloat(i: Int, value: Float): Unit = delegate.setFloat(i, value) - override def setDouble(i: Int, value: Double): Unit = delegate.setDouble(i, value) - override def setDecimal(i: Int, value: Decimal, precision: Int): Unit = - delegate.setDecimal(i, value, precision) - def numFields: Int = delegate.numFields - offset - def copy(): InternalRow = this -} From c579ca6761d8dee0385f2efe14fd4e408b1248a0 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 10 Aug 2015 18:27:17 -0400 Subject: [PATCH 19/19] Removed left over from performance tuning. Cleanup some comments. --- .../expressions/codegen/CodeGenerator.scala | 16 -- .../codegen/GenerateMutableProjection.scala | 1 - .../codegen/GenerateProjection.scala | 1 - .../codegen/GenerateUnsafeProjection.scala | 149 ------------------ .../expressions/windowExpressions.scala | 5 - 5 files changed, 172 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1c9d43d07552..7b41c9a3f3b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -90,22 +90,6 @@ class CodeGenContext { addedFuntions += ((funcName, funcCode)) } - /** - * Set to true when we are expecting a JoinedRow as input. - */ - var join: Boolean = false - def unwrapJoinRow: String = { - if (join) { - // We could also add the left()/right() methods to the InternalRow interface. 
- """ - |org.apache.spark.sql.catalyst.expressions.JoinedRow join = (org.apache.spark.sql.catalyst.expressions.JoinedRow) i; - |InternalRow left = join.left(); - |InternalRow right = join.right(); - """.stripMargin - } - else "" - } - final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index dc49694594e5..ac58423cd884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -127,7 +127,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - ${ctx.unwrapJoinRow} $projectionCalls return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 18951449b401..c744e84d822e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -179,7 +179,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $columns public SpecificRow(InternalRow i) { - ${ctx.unwrapJoinRow} $initColumns } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 55377abc1560..d8912df694a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -116,155 +116,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") } - /** - * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 55377abc1560..d8912df694a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -116,155 +116,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
   }
 
-  /**
-   * Generates the code to create an [[UnsafeRow]] object based on the input expressions.
-   * @param ctx context for code generation
-   * @param ev specifies the name of the variable for the output [[UnsafeRow]] object
-   * @param expressions input expressions
-   * @return generated code to put the expression output into an [[UnsafeRow]]
-   */
-  def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression])
-    : String = {
-
-    val ret = ev.primitive
-    ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
-    val buffer = ctx.freshName("buffer")
-    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
-    val cursor = ctx.freshName("cursor")
-    val numBytes = ctx.freshName("numBytes")
-
-    val exprs = expressions.map { e => e.dataType match {
-      case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st)
-      case _ => e.gen(ctx)
-    }}
-    val allExprs = exprs.map(_.code).mkString("\n")
-
-    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
-    val additionalSize = expressions.zipWithIndex.map {
-      case (e, i) => genAdditionalSize(e.dataType, exprs(i))
-    }.mkString("")
-
-    val writers = expressions.zipWithIndex.map { case (e, i) =>
-      val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor)
-      s"""if (${exprs(i).isNull}) {
-            $ret.setNullAt($i);
-          } else {
-            $update;
-          }"""
-    }.mkString("\n          ")
-
-    s"""
-      ${ctx.unwrapJoinRow}
-      $allExprs
-      int $numBytes = $fixedSize $additionalSize;
-      if ($numBytes > $buffer.length) {
-        $buffer = new byte[$numBytes];
-      }
-
-      $ret.pointTo(
-        $buffer,
-        $PlatformDependent.BYTE_ARRAY_OFFSET,
-        ${expressions.size},
-        $numBytes);
-      int $cursor = $fixedSize;
-
-      $writers
-      boolean ${ev.isNull} = false;
-    """
-  }
-
-  /**
-   * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
-   *
-   * This function also handles nested structs by recursively generating the code to do conversion.
-   *
-   * @param ctx code generation context
-   * @param input the input struct, identified by a [[GeneratedExpressionCode]]
-   * @param schema schema of the struct field
-   */
-  // TODO: refactor createCode and this function to reduce code duplication.
-  private def createCodeForStruct(
-      ctx: CodeGenContext,
-      input: GeneratedExpressionCode,
-      schema: StructType): GeneratedExpressionCode = {
-
-    val isNull = input.isNull
-    val primitive = ctx.freshName("structConvert")
-    ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
-    val buffer = ctx.freshName("buffer")
-    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
-    val cursor = ctx.freshName("cursor")
-
-    val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
-      case (dt, i) => dt match {
-        case st: StructType =>
-          val nestedStructEv = GeneratedExpressionCode(
-            code = "",
-            isNull = s"${input.primitive}.isNullAt($i)",
-            primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
-          )
-          createCodeForStruct(ctx, nestedStructEv, st)
-        case _ =>
-          GeneratedExpressionCode(
-            code = "",
-            isNull = s"${input.primitive}.isNullAt($i)",
-            primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
-          )
-      }
-    }
-    val allExprs = exprs.map(_.code).mkString("\n")
-
-    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
-    val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
-      genAdditionalSize(dt, ev)
-    }.mkString("")
-
-    val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
-      val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
-      s"""
-          if (${exprs(i).isNull}) {
-            $primitive.setNullAt($i);
-          } else {
-            $update;
-          }
-      """
-    }.mkString("\n          ")
-
-    // Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
-    // just copy the bytes directly into our buffer space without running any conversion.
-    // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from
-    // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
-    val tmp = ctx.freshName("tmp")
-    val numBytes = ctx.freshName("numBytes")
-    val code = s"""
-       |${input.code}
-       |if (!${input.isNull}) {
-       |  Object $tmp = (Object) ${input.primitive};
-       |  if ($tmp instanceof UnsafeRow) {
-       |    $primitive = (UnsafeRow) $tmp;
-       |  } else {
-       |    $allExprs
-       |
-       |    int $numBytes = $fixedSize $additionalSize;
-       |    if ($numBytes > $buffer.length) {
-       |      $buffer = new byte[$numBytes];
-       |    }
-       |
-       |    $primitive.pointTo(
-       |      $buffer,
-       |      $PlatformDependent.BYTE_ARRAY_OFFSET,
-       |      ${exprs.size},
-       |      $numBytes);
-       |    int $cursor = $fixedSize;
-       |
-       |    $writers
-       |  }
-       |}
-     """.stripMargin
-
-    GeneratedExpressionCode(code, isNull, primitive)
-  }
-
   /**
    * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
    *
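Note: the 149 deleted lines above were a hand-specialized duplicate of the UnsafeRow conversion path; the surviving `createCodeForStruct` keeps the same layout math. The fixed region it budgets is 8 bytes per field plus a null bitset, with variable-length data appended behind it. A quick sanity check of that arithmetic; the bitset formula (round fields up to whole 64-bit words) is assumed from `UnsafeRow.calculateBitSetWidthInBytes` and worth verifying against that class:

    object UnsafeRowSizeSketch extends App {
      // Null bitset: one bit per field, rounded up to 8-byte words (assumed formula).
      def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

      // Fixed region: 8 bytes per field behind the bitset; variable-length
      // values (strings, nested structs) are appended after this region.
      def fixedSize(numFields: Int): Int = 8 * numFields + bitSetWidthInBytes(numFields)

      println(fixedSize(3))  // 24 + 8  = 32
      println(fixedSize(70)) // 560 + 16 = 576
    }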
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 815d0403d70e..cb34a6f2337a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -422,7 +422,6 @@ case class RowNumber() extends RowNumberLike {
   override val evaluateExpression = Cast(rowNumber, IntegerType)
 }
 
-// TODO check if this works in combination with CodeGeneration?
 case class CumeDist(n: Expression) extends RowNumberLike with SizeBasedWindowFunction {
   def this() = this(Literal(0))
   override def dataType: DataType = DoubleType
@@ -432,9 +431,6 @@ case class CumeDist(n: Expression) extends RowNumberLike with SizeBasedWindowFun
   override val evaluateExpression = Cast(rowNumber, DoubleType) / Cast(n, DoubleType)
 }
 
-// TODO check if this works in combination with CodeGeneration?
-// TODO check logic
-// Check serialization
 case class NTile(buckets: Expression, n: Expression)
   extends RowNumberLike with SizeBasedWindowFunction {
   def this() = this(Literal(1), Literal(0))
@@ -513,7 +509,6 @@ case class DenseRank(order: Seq[Expression]) extends RankLike {
   override val initialValues = Literal(0) +: orderInit
 }
 
-// TODO check if this works in combination with CodeGeneration?
 case class PercentRank(order: Seq[Expression], n: Expression) extends RankLike
   with SizeBasedWindowFunction {
   def this() = this(Nil, MutableLiteral(0, IntegerType))