Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix merging error.
  • Loading branch information
viirya committed Oct 22, 2015
commit d3e441457f8c0243170fa2f6a8408c0c1ed6bc99
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = Cast(currentCount, LongType)
}

case class First(child: Expression) extends DeclarativeAggregate {
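/**
* Returns the first value of `child` for a group of rows. By default nulls are respected: if
* the first value of `child` is `null`, it returns `null`. When `ignoreNullsExpr` is the
* literal `true`, null values of `child` are skipped and the first non-null value is returned.
* Even if [[First]] is used on an already sorted column, if we do partial aggregation and
* final aggregation (when mergeExpression is used) its result will not be deterministic
* (unless the input table is sorted and has a single partition, and we use a single reducer
* to do the aggregation).
* @param child the expression whose first value is returned for each group.
* @param ignoreNullsExpr a boolean literal; `true` means null values of `child` are ignored.
*/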
case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))

// Unwraps the boolean flag carried by `ignoreNullsExpr`. Only a boolean literal is accepted,
// so anything else is rejected at construction time.
private val ignoreNulls: Boolean = ignoreNullsExpr match {
  case Literal(flag: Boolean, BooleanType) =>
    flag
  case _ =>
    throw new AnalysisException("The second argument of First should be a boolean literal.")
}

override def children: Seq[Expression] = child :: Nil

Expand All @@ -138,24 +154,61 @@ case class First(child: Expression) extends DeclarativeAggregate {

private val first = AttributeReference("first", child.dataType)()

override val aggBufferAttributes = first :: Nil
private val valueSet = AttributeReference("valueSet", BooleanType)()

override val aggBufferAttributes = first :: valueSet :: Nil

override val initialValues = Seq(
/* first = */ Literal.create(null, child.dataType)
/* first = */ Literal.create(null, child.dataType),
/* valueSet = */ Literal.create(false, BooleanType)
)

override val updateExpressions = Seq(
/* first = */ If(IsNull(first), child, first)
)
// Per-row buffer update. `valueSet` records whether `first` has already been fixed, so a
// legitimately null first value is not overwritten by later rows.
override val updateExpressions = {
  if (!ignoreNulls) {
    // Respecting nulls: the first row seen wins, whatever its value is.
    val keepFirst = If(valueSet, first, child)
    val markSet = Literal.create(true, BooleanType)
    Seq(/* first = */ keepFirst, /* valueSet = */ markSet)
  } else {
    // Ignoring nulls: keep the buffer while it is already set or the incoming value is null;
    // the buffer only becomes "set" once a non-null value has been seen.
    val keepBuffer = Or(valueSet, IsNull(child))
    val markSet = Or(valueSet, IsNotNull(child))
    Seq(/* first = */ If(keepBuffer, first, child), /* valueSet = */ markSet)
  }
}

override val mergeExpressions = Seq(
/* first = */ If(IsNull(first.left), first.right, first.left)
)
override val mergeExpressions = {
// For first, we can just check if valueSet.left is set to true. If it is set
// to true, we use first.left. If not, we use first.right (even if valueSet.right is
// false, we are safe to do so because first.right will be null in this case).
Seq(
/* first = */ If(valueSet.left, first.left, first.right),
/* valueSet = */ Or(valueSet.left, valueSet.right)
)
}

override val evaluateExpression = first

// Renders as `FIRST(child)`, followed by " IGNORE NULLS" when nulls are skipped.
// The explicit `else ""` matters: an `if` without `else` evaluates to Unit, which the
// interpolator would print as "()" whenever `ignoreNulls` is false.
override def toString: String = {
  val suffix = if (ignoreNulls) " IGNORE NULLS" else ""
  s"FIRST($child)$suffix"
}
}

case class Last(child: Expression) extends DeclarativeAggregate {
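/**
* Returns the last value of `child` for a group of rows. By default nulls are respected: if
* the last value of `child` is `null`, it returns `null`. When `ignoreNullsExpr` is the
* literal `true`, null values of `child` are skipped and the last non-null value is returned.
* Even if [[Last]] is used on an already sorted column, if we do partial aggregation and
* final aggregation (when mergeExpression is used) its result will not be deterministic
* (unless the input table is sorted and has a single partition, and we use a single reducer
* to do the aggregation).
* @param child the expression whose last value is returned for each group.
* @param ignoreNullsExpr a boolean literal; `true` means null values of `child` are ignored.
*/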
case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))

// Unwraps the boolean flag carried by `ignoreNullsExpr`. Only a boolean literal is accepted,
// so anything else is rejected at construction time. The message names `Last` (it was
// previously copied verbatim from `First`).
private val ignoreNulls: Boolean = ignoreNullsExpr match {
  case Literal(b: Boolean, BooleanType) => b
  case _ =>
    throw new AnalysisException("The second argument of Last should be a boolean literal.")
}

override def children: Seq[Expression] = child :: Nil

Expand All @@ -178,15 +231,33 @@ case class Last(child: Expression) extends DeclarativeAggregate {
/* last = */ Literal.create(null, child.dataType)
)

override val updateExpressions = Seq(
/* last = */ If(IsNull(child), last, child)
)
// Per-row buffer update: every row overwrites `last`, except that null inputs leave the
// buffer untouched when nulls are being ignored.
override val updateExpressions = {
  val nextLast = if (ignoreNulls) If(IsNull(child), last, child) else child
  Seq(/* last = */ nextLast)
}

override val mergeExpressions = Seq(
/* last = */ If(IsNull(last.right), last.left, last.right)
)
// Combines two partial buffers; the right-hand buffer is treated as the later one.
// NOTE(review): when nulls are respected the right buffer is taken unconditionally. If a
// right-hand buffer that never saw a row can reach this merge, its null would replace
// a valid left value -- confirm whether that can happen.
override val mergeExpressions = {
  if (!ignoreNulls) {
    Seq(/* last = */ last.right)
  } else {
    // Ignoring nulls: fall back to the left buffer when the right one holds null.
    Seq(/* last = */ If(IsNull(last.right), last.left, last.right))
  }
}

override val evaluateExpression = last

// Renders as `LAST(child)`, followed by " IGNORE NULLS" when nulls are skipped.
// The explicit `else ""` matters: an `if` without `else` evaluates to Unit, which the
// interpolator would print as "()" whenever `ignoreNulls` is false.
override def toString: String = {
  val suffix = if (ignoreNulls) " IGNORE NULLS" else ""
  s"LAST($child)$suffix"
}
}

case class Max(child: Expression) extends DeclarativeAggregate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ object Utils {
mode = aggregate.Complete,
isDistinct = true)

case expressions.First(child) =>
case expressions.First(child, ignoreNulls) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.First(child),
aggregateFunction = aggregate.First(child, ignoreNulls),
mode = aggregate.Complete,
isDistinct = false)

case expressions.Last(child) =>
case expressions.Last(child, ignoreNulls) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Last(child),
aggregateFunction = aggregate.Last(child, ignoreNulls),
mode = aggregate.Complete,
isDistinct = false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import com.clearspring.analytics.stream.cardinality.HyperLogLog

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
Expand Down Expand Up @@ -630,59 +631,113 @@ case class CombineSetsAndSumFunction(
}
}

case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
case class First(
child: Expression,
ignoreNullsExpr: Expression)
extends UnaryExpression with PartialAggregate1 {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))

// Unwraps the boolean flag carried by `ignoreNullsExpr`. Only a boolean literal is accepted,
// so anything else is rejected at construction time.
private val ignoreNulls: Boolean = ignoreNullsExpr match {
  case Literal(flag: Boolean, BooleanType) =>
    flag
  case _ =>
    throw new AnalysisException("The second argument of First should be a boolean literal.")
}

override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST($child)"
// Renders as `FIRST(child)` or `FIRST(child IGNORE NULLS)`.
// The explicit `else ""` matters: an `if` without `else` evaluates to Unit, which the
// interpolator would print as "()" whenever `ignoreNulls` is false.
override def toString: String = {
  val modifier = if (ignoreNulls) " IGNORE NULLS" else ""
  s"FIRST(${child}${modifier})"
}

override def asPartial: SplitEvaluation = {
val partialFirst = Alias(First(child), "PartialFirst")()
val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
SplitEvaluation(
First(partialFirst.toAttribute),
First(partialFirst.toAttribute, ignoreNulls),
partialFirst :: Nil)
}
override def newInstance(): FirstFunction = new FirstFunction(child, this)
override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
}

case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
object First {
def apply(child: Expression): First = First(child, ignoreNulls = false)

var result: Any = null
def apply(child: Expression, ignoreNulls: Boolean): First =
First(child, Literal.create(ignoreNulls, BooleanType))
}

case class FirstFunction(
expr: Expression,
ignoreNulls: Boolean,
base: AggregateExpression1)
extends AggregateFunction1 {

// No-arg constructor required for serialization. `false` is the same value that
// `null.asInstanceOf[Boolean]` produced, stated directly.
def this() = this(null, false, null)

private[this] var result: Any = null

private[this] var valueSet: Boolean = false

override def update(input: InternalRow): Unit = {
// We ignore null values.
if (result == null) {
result = expr.eval(input)
if (!valueSet) {
val value = expr.eval(input)
// When we have not set the result, we will set the result if we respect nulls
// (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
if (!ignoreNulls || (ignoreNulls && value != null)) {
result = value
valueSet = true
}
}
}

override def eval(input: InternalRow): Any = result
}

case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
case class Last(
child: Expression,
ignoreNullsExpr: Expression)
extends UnaryExpression with PartialAggregate1 {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))

// Unwraps the boolean flag carried by `ignoreNullsExpr`. Only a boolean literal is accepted,
// so anything else is rejected at construction time. The message names `Last` (it was
// previously copied verbatim from `First`).
private val ignoreNulls: Boolean = ignoreNullsExpr match {
  case Literal(b: Boolean, BooleanType) => b
  case _ =>
    throw new AnalysisException("The second argument of Last should be a boolean literal.")
}

override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"LAST($child)"
// Renders as `LAST(child)`, followed by " IGNORE NULLS" when nulls are skipped.
// The explicit `else ""` matters: an `if` without `else` evaluates to Unit, which the
// interpolator would print as "()" whenever `ignoreNulls` is false.
override def toString: String = {
  val suffix = if (ignoreNulls) " IGNORE NULLS" else ""
  s"LAST($child)$suffix"
}

override def asPartial: SplitEvaluation = {
val partialLast = Alias(Last(child), "PartialLast")()
val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
SplitEvaluation(
Last(partialLast.toAttribute),
Last(partialLast.toAttribute, ignoreNulls),
partialLast :: Nil)
}
override def newInstance(): LastFunction = new LastFunction(child, this)
override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
}

case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
/**
 * Convenience factories for [[Last]] that take the ignore-nulls flag as a plain Boolean
 * (or omit it) and wrap it into the boolean literal the case class expects.
 */
object Last {
  /** Builds a [[Last]] whose flag is the boolean literal for `ignoreNulls`. */
  def apply(child: Expression, ignoreNulls: Boolean): Last = {
    val flagLiteral = Literal.create(ignoreNulls, BooleanType)
    Last(child, flagLiteral)
  }

  /** Builds a [[Last]] that respects nulls. */
  def apply(child: Expression): Last = apply(child, ignoreNulls = false)
}

case class LastFunction(
expr: Expression,
ignoreNulls: Boolean,
base: AggregateExpression1)
extends AggregateFunction1 {

// No-arg constructor required for serialization. `false` is the same value that
// `null.asInstanceOf[Boolean]` produced, stated directly.
def this() = this(null, false, null)

var result: Any = null

override def update(input: InternalRow): Unit = {
val value = expr.eval(input)
// We ignore null values.
if (value != null) {
if (!ignoreNulls || (ignoreNulls && value != null)) {
result = value
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.expressions

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.{Column, catalyst}
import org.apache.spark.sql.catalyst.expressions._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(11.125) :: Nil)
}

// Exercises first/last and their first_value/last_value spellings (written in mixed case),
// both with a single argument and with `true` as the second argument. Per the asserted rows,
// the two-argument form is expected to skip null keys.
test("first_value and last_value") {
// We force to use a single partition for the sort and aggregate to make result
// deterministic.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
// Ascending sort. The expected row has null as the first value and 3 as the last, so the
// null key presumably sorts first here -- verify against the agg1 fixture.
checkAnswer(
sqlContext.sql(
"""
|SELECT
| first_valUE(key),
| lasT_value(key),
| firSt(key),
| lASt(key),
| first_valUE(key, true),
| lasT_value(key, true),
| firSt(key, true),
| lASt(key, true)
|FROM (SELECT key FROM agg1 ORDER BY key) tmp
""".stripMargin),
Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil)

// Descending sort: the expectations mirror the ascending case (3 first, null last; with
// the `true` argument the last value becomes 1).
checkAnswer(
sqlContext.sql(
"""
|SELECT
| first_valUE(key),
| lasT_value(key),
| firSt(key),
| lASt(key),
| first_valUE(key, true),
| lasT_value(key, true),
| firSt(key, true),
| lASt(key, true)
|FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp
""".stripMargin),
Row(3, null, 3, null, 3, 1, 3, 1) :: Nil)
}
}

test("udaf") {
checkAnswer(
sqlContext.sql(
Expand Down