Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types._

/**
* :: DeveloperApi ::
* Cells used to support aggregating over nested fields.
* @param child the input data source.
*/
case class SumCell(child: Expression) extends UnaryExpression{
type EvaluatedType = Any

override def eval(input: Row): Any = {
val evalE = child.eval(input)
evalE match {
case seq: Seq[Any] => seq.reduce((a, b) => numeric.plus(a, b))
case _ => evalE
}
}

override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable

lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
case other => sys.error(s"Type $other does not support numeric operations")
}

override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case ArrayType(dataType, _) =>
dataType
case _ =>
child.dataType
}
}

case class CountCell(child: Expression) extends UnaryExpression{
type EvaluatedType = Any

override def eval(input: Row): Any = {
val evalE = child.eval(input)
evalE match {
case seq: Seq[Any] => seq.size.toLong
case p if p != null => 1L
case _ => null
}
}

override def nullable: Boolean = false
override def dataType: DataType = LongType
}

case class MinCell(child: Expression) extends UnaryExpression{
type EvaluatedType = Any

override def eval(input: Row): Any = {
val evalE = child.eval(input)
evalE match {
case seq: Seq[Any] => seq.reduce((a, b) => if (ordering.compare(a, b) < 0) a else b)
case _ => evalE
}
}

override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable

lazy val ordering = dataType match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}

override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case ArrayType(dataType, _) =>
dataType
case _ =>
child.dataType
}
}

case class MaxCell(child: Expression) extends UnaryExpression{
type EvaluatedType = Any

override def eval(input: Row): Any = {
val evalE = child.eval(input)
evalE match {
case seq: Seq[Any] => seq.reduce((a, b) => if (ordering.compare(a, b) > 0) a else b)
case _ => evalE
}
}

lazy val ordering = dataType match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}

override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable

override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case ArrayType(dataType, _) =>
dataType
case _ =>
child.dataType
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def toString: String = s"MIN($child)"

override def asPartial: SplitEvaluation = {
val partialMin = Alias(Min(child), "PartialMin")()
val partialMin = Alias(Min(MinCell(child)), "PartialMin")()
SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
}

Expand Down Expand Up @@ -128,7 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def toString: String = s"MAX($child)"

override def asPartial: SplitEvaluation = {
val partialMax = Alias(Max(child), "PartialMax")()
val partialMax = Alias(Max(MaxCell(child)), "PartialMax")()
SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
}

Expand Down Expand Up @@ -159,7 +159,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def toString: String = s"COUNT($child)"

override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
val partialCount = Alias(Count(CountCell(child)), "PartialCount")()
SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}

Expand Down Expand Up @@ -328,8 +328,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
child.dataType match {
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
// Turn the child to unlimited decimals for calculation, before going back to fixed
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val partialSum = Alias(Sum(SumCell(Cast(child, DecimalType.Unlimited))), "PartialSum")()
val partialCount = Alias(Count(CountCell(child)), "PartialCount")()

val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
Expand All @@ -338,8 +338,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)

case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val partialSum = Alias(Sum(SumCell(child)), "PartialSum")()
val partialCount = Alias(Count(CountCell(child)), "PartialCount")()

val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
Expand Down Expand Up @@ -370,13 +370,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
val partialSum = Alias(Sum(SumCell(Cast(child, DecimalType.Unlimited))), "PartialSum")()
SplitEvaluation(
Cast(CombineSum(partialSum.toAttribute), dataType),
partialSum :: Nil)

case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
val partialSum = Alias(Sum(SumCell(child)), "PartialSum")()
SplitEvaluation(
CombineSum(partialSum.toAttribute),
partialSum :: Nil)
Expand Down Expand Up @@ -560,7 +560,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1L
count += evaluatedExpr.asInstanceOf[Long]
}
}

Expand Down Expand Up @@ -618,7 +618,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr

private val sum = MutableLiteral(null, calcType)

private val addFunction =
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))

override def update(input: Row): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,118 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
""".children

case cell @ SumCell(e) if e.dataType.isInstanceOf[ArrayType] =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
.asInstanceOf[Seq[${termForType(cell.dataType)}]]
.reduce((a,b) => a + b)
}
""".children

case cell @ SumCell(e @ NumericType()) =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
}
""".children

case cell @ CountCell(e) if e.dataType.isInstanceOf[ArrayType] =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
.asInstanceOf[Seq[${termForType(cell.dataType)}]]
.size
}
""".children

case cell @ CountCell(e @ NumericType()) =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = 1L
}
""".children

case cell @ MinCell(e) if e.dataType.isInstanceOf[ArrayType] =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
.asInstanceOf[Seq[${termForType(cell.dataType)}]]
.min
}
""".children

case cell @ MinCell(e @ NumericType()) =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
}
""".children

case cell @ MaxCell(e) if e.dataType.isInstanceOf[ArrayType] =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
.asInstanceOf[Seq[${termForType(cell.dataType)}]]
.max
}
""".children

case cell @ MaxCell(e @ NumericType()) =>
val eval = expressionEvaluator(e)
q"""
..${eval.code}
var $nullTerm = false
var $primitiveTerm: ${termForType(cell.dataType)} = 0
if (${eval.nullTerm}) {
$nullTerm = true
} else {
$primitiveTerm = ${eval.primitiveTerm}
}
""".children

case IsNotNull(e) =>
val eval = expressionEvaluator(e)
q"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ case class GeneratedAggregate(
}
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val initialValue = Literal(0L)
val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val updateFunction = If(IsNotNull(toCount), Add(currentCount, toCount), currentCount)
val result = currentCount

AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
Expand Down
45 changes: 45 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,51 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}

test("SPARK-7549 Support aggregating over nested fields") {
checkAnswer(sql("SELECT sum(a[0]) FROM complexData2"), Row(3))
checkAnswer(sql("SELECT sum(s.key) FROM complexData2"), Row(6))
checkAnswer(sql("SELECT sum(nestedData[0]) FROM arrayData"), Row(15))

checkAnswer(sql("SELECT count(a[0]) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT count(s.key) FROM complexData2"), Row(4))
checkAnswer(sql("SELECT count(nestedData[0]) FROM arrayData"), Row(6))

checkAnswer(sql("SELECT min(a[0]) FROM complexData2"), Row(1))
checkAnswer(sql("SELECT min(s.key) FROM complexData2"), Row(1))
checkAnswer(sql("SELECT min(nestedData[0]) FROM arrayData"), Row(1))

checkAnswer(sql("SELECT max(a[0]) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT max(s.key) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT max(nestedData[0]) FROM arrayData"), Row(4))

checkAnswer(sql("SELECT avg(a[0]) FROM complexData2"), Row(1.5))
checkAnswer(sql("SELECT avg(s.key) FROM complexData2"), Row(1.5))
checkAnswer(sql("SELECT avg(nestedData[0]) FROM arrayData"), Row(2.5))

val originalValue = conf.codegenEnabled
setConf(SQLConf.CODEGEN_ENABLED, "true")
checkAnswer(sql("SELECT sum(a[0]) FROM complexData2"), Row(3))
checkAnswer(sql("SELECT sum(s.key) FROM complexData2"), Row(6))
checkAnswer(sql("SELECT sum(nestedData[0]) FROM arrayData"), Row(15))

checkAnswer(sql("SELECT count(a[0]) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT count(s.key) FROM complexData2"), Row(4))
checkAnswer(sql("SELECT count(nestedData[0]) FROM arrayData"), Row(6))

checkAnswer(sql("SELECT min(a[0]) FROM complexData2"), Row(1))
checkAnswer(sql("SELECT min(s.key) FROM complexData2"), Row(1))
checkAnswer(sql("SELECT min(nestedData[0]) FROM arrayData"), Row(1))

checkAnswer(sql("SELECT max(a[0]) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT max(s.key) FROM complexData2"), Row(2))
checkAnswer(sql("SELECT max(nestedData[0]) FROM arrayData"), Row(4))

checkAnswer(sql("SELECT avg(a[0]) FROM complexData2"), Row(1.5))
checkAnswer(sql("SELECT avg(s.key) FROM complexData2"), Row(1.5))
checkAnswer(sql("SELECT avg(nestedData[0]) FROM arrayData"), Row(2.5))
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}

test("SPARK-6898: complete support for special chars in column names") {
jsonRDD(sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
Expand Down
8 changes: 8 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,12 @@ object TestData {
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")

case class ComplexData2(s: Seq[TestData], a: Seq[Int], b: Boolean)
val complexData2 =
TestSQLContext.sparkContext.parallelize(
ComplexData2(Seq[TestData](TestData(1, "1"), TestData(1, "2")), Seq(1), true)
:: ComplexData2(Seq[TestData](TestData(2, "2"), TestData(2, "3")), Seq(2), false)
:: Nil).toDF()
complexData2.registerTempTable("complexData2")
}