Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Renamed apply to eval for generators and added a bunch of override's.
  • Loading branch information
rxin committed Apr 7, 2014
commit 1a47e10e749f3b13d02da40e07988b1b5c4f7dfd
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}

Expand Down Expand Up @@ -231,7 +231,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

/** Foldable only when both the left and the right operand are foldable. */
override def foldable: Boolean = left.foldable && right.foldable

def references = left.references ++ right.references
override def references = left.references ++ right.references

/** Renders the expression in infix form: left operand, operator symbol, right operand. */
override def toString: String = s"($left $symbol $right)"
}
Expand All @@ -243,5 +243,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>

def references = child.references
override def references = child.references
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ case class SplitEvaluation(
partialEvaluations: Seq[NamedExpression])

/**
* An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples.
* An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
abstract class PartialAggregate extends AggregateExpression {
Expand All @@ -63,28 +63,28 @@ abstract class AggregateFunction
extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
self: Product =>

type EvaluatedType = Any
override type EvaluatedType = Any

/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
def references = base.references
def nullable = base.nullable
def dataType = base.dataType
override def references = base.references
override def nullable = base.nullable
override def dataType = base.dataType

def update(input: Row): Unit
override def eval(input: Row): Any

// Do we really need this?
def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
override def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}

case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
def references = child.references
def nullable = false
def dataType = IntegerType
override def references = child.references
override def nullable = false
override def dataType = IntegerType
/** Renders this aggregate as `COUNT(<child>)`. */
override def toString: String = s"COUNT($child)"

def asPartial: SplitEvaluation = {
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}
Expand All @@ -93,18 +93,18 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
}

case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
def children = expressions
def references = expressions.flatMap(_.references).toSet
def nullable = false
def dataType = IntegerType
override def children = expressions
override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
/** Renders this aggregate as `COUNT(DISTINCT e1,e2,...)`, joining the expressions with commas. */
// The previous literal had a stray `}` after the interpolation, which leaked into the output.
override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
/** Creates a fresh [[CountDistinctFunction]] over the same expressions, with this node as its base. */
override def newInstance: CountDistinctFunction = new CountDistinctFunction(expressions, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
def references = child.references
def nullable = false
def dataType = DoubleType
override def references = child.references
override def nullable = false
override def dataType = DoubleType
/** Renders this aggregate as `AVG(<child>)`. */
override def toString: String = s"AVG($child)"

override def asPartial: SplitEvaluation = {
Expand All @@ -122,9 +122,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}

case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
def references = child.references
def nullable = false
def dataType = child.dataType
override def references = child.references
override def nullable = false
override def dataType = child.dataType
/** Renders this aggregate as `SUM(<child>)`. */
override def toString: String = s"SUM($child)"

override def asPartial: SplitEvaluation = {
Expand All @@ -140,18 +140,18 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {

def references = child.references
def nullable = false
def dataType = child.dataType
override def references = child.references
override def nullable = false
override def dataType = child.dataType
/** Renders this aggregate as `SUM(DISTINCT <child>)`. */
override def toString: String = s"SUM(DISTINCT $child)"

/** Creates a fresh [[SumDistinctFunction]] over the same child, with this node as its base. */
override def newInstance: SumDistinctFunction = new SumDistinctFunction(child, this)
}

case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
def references = child.references
def nullable = child.nullable
def dataType = child.dataType
override def references = child.references
override def nullable = child.nullable
override def dataType = child.dataType
/** Renders this aggregate as `FIRST(<child>)`. */
override def toString: String = s"FIRST($child)"

override def asPartial: SplitEvaluation = {
Expand All @@ -172,14 +172,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)



private val addFunction = Add(sum, expr)

/**
 * Returns the accumulated sum, evaluated as a Double, divided by the number of rows counted.
 * The `input` row is not consulted; the sum is evaluated against `EmptyRow`.
 */
override def eval(input: Row): Any = {
  val total = sumAsDouble.eval(EmptyRow).asInstanceOf[Double]
  total / count.toDouble
}

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
count += 1
sum.update(addFunction, input)
}
Expand All @@ -190,7 +188,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag

var count: Int = _

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
count += 1
Expand All @@ -207,7 +205,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr

private val addFunction = Add(sum, expr)

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
sum.update(addFunction, input)
}

Expand All @@ -219,9 +217,9 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)

// No-argument constructor required for serialization; passes null for both expr and base.
def this() = this(null, null) // Required for serialization.

val seen = new scala.collection.mutable.HashSet[Any]()
private val seen = new scala.collection.mutable.HashSet[Any]()

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
seen += evaluatedExpr
Expand All @@ -239,7 +237,7 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio

val seen = new scala.collection.mutable.HashSet[Any]()

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
seen += evaluatedExpr
Expand All @@ -254,7 +252,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag

var result: Any = null

def update(input: Row): Unit = {
override def update(input: Row): Unit = {
if (result == null) {
result = expr.eval(input)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ import org.apache.spark.sql.catalyst.types._
* requested. The attributes produced by this function will be automatically copied anytime rules
* result in changes to the Generator or its children.
*/
abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
abstract class Generator extends Expression {
self: Product =>

type EvaluatedType = TraversableOnce[Row]
override type EvaluatedType = TraversableOnce[Row]

lazy val dataType =
override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))

def nullable = false
override def nullable = false

def references = children.flatMap(_.references).toSet
override def references = children.flatMap(_.references).toSet

/**
* Should be overridden by specific generators. Called only once for each instance to ensure
Expand All @@ -63,7 +63,7 @@ abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
}

/** Should be implemented by child classes to produce the output rows for a given input row. */
def apply(input: Row): TraversableOnce[Row]
override def eval(input: Row): TraversableOnce[Row]

/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
Expand All @@ -83,7 +83,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

lazy val elementTypes = child.dataType match {
private lazy val elementTypes = child.dataType match {
case ArrayType(et) => et :: Nil
case MapType(kt,vt) => kt :: vt :: Nil
}
Expand All @@ -100,7 +100,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
}
}

override def apply(input: Row): TraversableOnce[Row] = {
override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_) =>
val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {

def output =
override def output =
if (join) child.output ++ generator.output else generator.output

def execute() = {
override def execute() = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
Expand All @@ -52,7 +52,7 @@ case class Generate(
val joinedRow = new JoinedRow

iter.flatMap {row =>
val outputRows = generator(row)
val outputRows = generator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
Expand All @@ -61,7 +61,7 @@ case class Generate(
}
}
} else {
child.execute().mapPartitions(iter => iter.flatMap(generator))
child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generato

val Seq(nameAttr, ageAttr) = input

override def apply(input: Row): TraversableOnce[Row] = {
override def eval(input: Row): TraversableOnce[Row] = {
val name = nameAttr.eval(input)
val age = ageAttr.eval(input).asInstanceOf[Int]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ case class HiveGenericUdtf(
}
}

override def apply(input: Row): TraversableOnce[Row] = {
override def eval(input: Row): TraversableOnce[Row] = {
outputInspectors // Make sure initialized.

val inputProjection = new Projection(children)
Expand Down