Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
conditional function: least/greatest
  • Loading branch information
adrian-wang committed Jul 10, 2015
commit ec625b0f2267c03cfc5445f0da03038c3b959320
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ object FunctionRegistry {
expression[CreateArray]("array"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[Coalesce]("nvl"),
expression[Rand]("rand"),
expression[Randn]("randn"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,63 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
}.mkString
}
}

case class Least(children: Expression*)
extends Expression {

override def nullable: Boolean = children.forall(_.nullable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't least be nullable if any is nullable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hive 1.2.0 only return null when all arguments are null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any child is not nullable, the result will not be nullable.


override def checkInputDataTypes(): TypeCheckResult = {
if (children.map(_.dataType).distinct.size > 1) {
TypeCheckResult.TypeCheckFailure(
s"differing types in Least (${children.map(_.dataType)}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
val cmp = GreaterThan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it as the property of the class? We don't want to create new instance for every row, right? And don't forget to mark it as transient.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, I just notice it's impossible, but after checking the code in GreateThan, we can borrow the ordering = TypeUtils.getOrdering(left.dataType)

children.foldLeft[Expression](null)((r, c) => {
if (c != null) {
if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r
} else {
r
}
}).eval(input)
}

override def toString: String = s"LEAST(${children.mkString(", ")})"
}

case class Greatest(children: Expression*)
extends Expression {

override def nullable: Boolean = children.forall(_.nullable)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.map(_.dataType).distinct.size > 1) {
TypeCheckResult.TypeCheckFailure(
s"differing types in Greatest (${children.map(_.dataType)}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def dataType: DataType = children.head.dataType

override def eval(input: InternalRow): Any = {
val cmp = LessThan
children.foldLeft[Expression](null)((r, c) => {
if (c != null) {
if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r
} else {
r
}
}).eval(input)
}

override def toString: String = s"LEAST(${children.mkString(", ")})"
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,21 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}

test("greatest/least") {
val row = create_row(1, 2, "a", "b", "c")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to test this for all data types.

val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.string.at(2)
val c4 = 'a.string.at(3)
val c5 = 'a.string.at(4)
checkEvaluation(Greatest(c4, c5, c3), "c", row)
checkEvaluation(Greatest(c2, c1), 2, row)
checkEvaluation(Least(c4, c3, c5), "a", row)
checkEvaluation(Least(c1, c2), 1, row)
checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row)
checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row)
checkEvaluation(Least(c1, c2, Literal(-1)), -1, row)
checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row)
}

}
44 changes: 41 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ object functions {
/**
* Creates a new row for each element in the given array or map column.
*/
def explode(e: Column): Column = Explode(e.expr)
def explode(e: Column): Column = Explode(e.expr)

/**
* Converts a string exprsesion to lower case.
Expand Down Expand Up @@ -1073,11 +1073,30 @@ object functions {
def floor(columnName: String): Column = floor(Column(columnName))

/**
* Computes hex value of the given column
* Returns the greatest value of the list of values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should also document semantics regarding null values

*
* @group math_funcs
* @group normal_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(exprs: Column*): Column = Greatest(exprs.map(_.expr): _*)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that exprs have at least two columns? Also for others, it's good to fail fast.


/**
* Returns the greatest value of the list of column names.
*
* @group normal_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(columnName: String, columnNames: String*): Column =
greatest((columnName +: columnNames).map(Column.apply): _*)

/**
* Computes hex value of the given column
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:
/**

  • Computes hex value for the given column.
    *

*
* @group math_funcs
* @since 1.5.0
*/
def hex(column: Column): Column = Hex(column.expr)

/**
Expand Down Expand Up @@ -1171,6 +1190,25 @@ object functions {
*/
def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName))

/**
* Returns the least value of the list of values.
*
* @group normal_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def least(exprs: Column*): Column = Least(exprs.map(_.expr): _*)

/**
* Returns the least value of the list of column names.
*
* @group normal_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def least(columnName: String, columnNames: String*): Column =
least((columnName +: columnNames).map(Column.apply): _*)

/**
* Computes the natural logarithm of the given value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,26 @@ class DataFrameFunctionsSuite extends QueryTest {
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc")))
}

test("conditional function: least") {
checkAnswer(
testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
Row(-1)
)
checkAnswer(
ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
)
}

test("conditional function: greatest") {
checkAnswer(
testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
Row(3)
)
checkAnswer(
ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}
}