From ec625b0f2267c03cfc5445f0da03038c3b959320 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 17 Jun 2015 00:27:55 -0700 Subject: [PATCH 1/6] conditional function: least/greatest --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../catalyst/expressions/conditionals.scala | 60 +++++++++++++++++++ .../ConditionalExpressionSuite.scala | 17 ++++++ .../org/apache/spark/sql/functions.scala | 44 +++++++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 22 +++++++ 5 files changed, 142 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f62d79f8cea6..ed69c42dcb82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -76,9 +76,11 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), + expression[Greatest]("greatest"), expression[If]("if"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), + expression[Least]("least"), expression[Coalesce]("nvl"), expression[Rand]("rand"), expression[Randn]("randn"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 395e84f089e4..0081ca97f265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -312,3 +312,63 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW }.mkString } } + +case class Least(children: Expression*) + extends Expression { + + override def nullable: Boolean = children.forall(_.nullable) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"differing types in Least (${children.map(_.dataType)}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = children.head.dataType + + override def eval(input: InternalRow): Any = { + val cmp = GreaterThan + children.foldLeft[Expression](null)((r, c) => { + if (c != null) { + if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r + } else { + r + } + }).eval(input) + } + + override def toString: String = s"LEAST(${children.mkString(", ")})" +} + +case class Greatest(children: Expression*) + extends Expression { + + override def nullable: Boolean = children.forall(_.nullable) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"differing types in Greatest (${children.map(_.dataType)}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = children.head.dataType + + override def eval(input: InternalRow): Any = { + val cmp = LessThan + children.foldLeft[Expression](null)((r, c) => { + if (c != null) { + if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r + } else { + r + } + }).eval(input) + } + + override def toString: String = s"LEAST(${children.mkString(", ")})" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 372848ea9a59..a9c3537a0927 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -134,4 +134,21 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) } + test("greatest/least") { + val row = create_row(1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.string.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + checkEvaluation(Greatest(c4, c5, c3), "c", row) + checkEvaluation(Greatest(c2, c1), 2, row) + checkEvaluation(Least(c4, c3, c5), "a", row) + checkEvaluation(Least(c1, c2), 1, row) + checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row) + checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row) + checkEvaluation(Least(c1, c2, Literal(-1)), -1, row) + checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 08bf37a5c223..f75bf1975aae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -599,7 +599,7 @@ object functions { /** * Creates a new row for each element in the given array or map column. */ - def explode(e: Column): Column = Explode(e.expr) + def explode(e: Column): Column = Explode(e.expr) /** * Converts a string exprsesion to lower case. @@ -1073,11 +1073,30 @@ object functions { def floor(columnName: String): Column = floor(Column(columnName)) /** - * Computes hex value of the given column + * Returns the greatest value of the list of values. * - * @group math_funcs + * @group normal_funcs * @since 1.5.0 */ + @scala.annotation.varargs + def greatest(exprs: Column*): Column = Greatest(exprs.map(_.expr): _*) + + /** + * Returns the greatest value of the list of column names. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def greatest(columnName: String, columnNames: String*): Column = + greatest((columnName +: columnNames).map(Column.apply): _*) + + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ def hex(column: Column): Column = Hex(column.expr) /** @@ -1171,6 +1190,25 @@ object functions { */ def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + /** + * Returns the least value of the list of values. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(exprs: Column*): Column = Least(exprs.map(_.expr): _*) + + /** + * Returns the least value of the list of column names. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(columnName: String, columnNames: String*): Column = + least((columnName +: columnNames).map(Column.apply): _*) + /** * Computes the natural logarithm of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 173280375c41..6cebec95d285 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -381,4 +381,26 @@ class DataFrameFunctionsSuite extends QueryTest { df.selectExpr("split(a, '[1-9]+')"), Row(Seq("aa", "bb", "cc"))) } + + test("conditional function: least") { + checkAnswer( + testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), + Row(-1) + ) + checkAnswer( + ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) + ) + } + + test("conditional function: greatest") { + checkAnswer( + testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1), + Row(3) + ) + checkAnswer( + ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) + ) + } } From c1f682435db24a530c3952f2231fe193d9ac4b3a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 9 Jul 2015 03:55:03 -0700 Subject: [PATCH 2/6] add codegen, test for all types --- .../catalyst/expressions/conditionals.scala | 87 ++++++++++++++----- .../ConditionalExpressionSuite.scala | 74 ++++++++++++++-- 2 files changed, 133 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 0081ca97f265..ccd453f94e78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BooleanType, DataType} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{NullType, BooleanType, DataType} case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -313,62 +314,102 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } } -case class Least(children: Expression*) - extends Expression { +case class Least(children: Expression*) extends Expression { + require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.size > 1) { + if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( - s"differing types in Least (${children.map(_.dataType)}).") + s"The expressions should all have the same type," + + s" got LEAST (${children.map(_.dataType)}).") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } } override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { - val cmp = GreaterThan - children.foldLeft[Expression](null)((r, c) => { - if (c != null) { - if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r + children.foldLeft[Any](null)((r, c) => { + val evalc = c.eval(input) + if (evalc != null) { + if (r == null || ordering.lt(evalc, r)) evalc else r } else { r } - }).eval(input) + }) } - override def toString: String = s"LEAST(${children.mkString(", ")})" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evalChildren = children.map(_.gen(ctx)) + def updateEval(i: Int): String = + s""" + if (${ev.isNull} || (!${evalChildren(i).isNull} && ${ + ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { + ${ev.isNull} = ${evalChildren(i).isNull}; + ${ev.primitive} = ${evalChildren(i).primitive}; + } + """ + s""" + ${evalChildren.map(_.code).mkString("\n")} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${(0 to children.length - 1).map(updateEval).mkString("\n")} + """ + } } -case class Greatest(children: Expression*) - extends Expression { +case class Greatest(children: Expression*) extends Expression { + require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.size > 1) { + if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( - s"differing types in Greatest (${children.map(_.dataType)}).") + s"The expressions should all have the same type," + + s" got GREATEST (${children.map(_.dataType)}).") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } } override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { - val cmp = LessThan - children.foldLeft[Expression](null)((r, c) => { - if (c != null) { - if (r == null || cmp.apply(r, c).eval(input).asInstanceOf[Boolean]) c else r + children.foldLeft[Any](null)((r, c) => { + val evalc = c.eval(input) + if (evalc != null) { + if (r == null || ordering.gt(evalc, r)) evalc else r } else { r } - }).eval(input) + }) } - override def toString: String = s"LEAST(${children.mkString(", ")})" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evalChildren = children.map(_.gen(ctx)) + def updateEval(i: Int): String = + s""" + if (${ev.isNull} || (!${evalChildren(i).isNull} && ${ + ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { + ${ev.isNull} = ${evalChildren(i).isNull}; + ${ev.primitive} = ${evalChildren(i).primitive}; + } + """ + s""" + ${evalChildren.map(_.code).mkString("\n")} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${(0 to children.length - 1).map(updateEval).mkString("\n")} + """ + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a9c3537a0927..aaf40cc83e76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Timestamp, Date} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -134,21 +137,82 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) } - test("greatest/least") { + test("function least") { val row = create_row(1, 2, "a", "b", "c") val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.string.at(2) val c4 = 'a.string.at(3) val c5 = 'a.string.at(4) - checkEvaluation(Greatest(c4, c5, c3), "c", row) - checkEvaluation(Greatest(c2, c1), 2, row) checkEvaluation(Least(c4, c3, c5), "a", row) checkEvaluation(Least(c1, c2), 1, row) - checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row) - checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row) checkEvaluation(Least(c1, c2, Literal(-1)), -1, row) checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row) + + checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty) + checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty) + checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty) + checkEvaluation( + Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty) + checkEvaluation( + Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty) + checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty) + checkEvaluation( + Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty) + checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty) + checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty) + checkEvaluation( + Least( + Literal(BigDecimal("1234567890987654321123456")), + Literal(BigDecimal("1234567890987654321123458"))), + BigDecimal("1234567890987654321123456"), InternalRow.empty) + checkEvaluation( + Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))), + Date.valueOf("2015-01-01"), InternalRow.empty) + checkEvaluation( + Least( + Literal(Timestamp.valueOf("2015-07-01 08:00:00")), + Literal(Timestamp.valueOf("2015-07-01 10:00:00"))), + Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + } + + test("function greatest") { + val row = create_row(1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.string.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + checkEvaluation(Greatest(c4, c5, c3), "c", row) + checkEvaluation(Greatest(c2, c1), 2, row) + checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row) + checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row) + + checkEvaluation(Greatest(Literal(null), Literal(null)), null, InternalRow.empty) + checkEvaluation(Greatest(Literal(-1.0), Literal(2.5)), 2.5, InternalRow.empty) + checkEvaluation(Greatest(Literal(-1), Literal(2)), 2, InternalRow.empty) + checkEvaluation( + Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat, InternalRow.empty) + checkEvaluation( + Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue, InternalRow.empty) + checkEvaluation(Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte, InternalRow.empty) + checkEvaluation( + Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort, InternalRow.empty) + checkEvaluation(Greatest(Literal("abc"), Literal("aaaa")), "abc", InternalRow.empty) + checkEvaluation(Greatest(Literal(true), Literal(false)), true, InternalRow.empty) + checkEvaluation( + Greatest( + Literal(BigDecimal("1234567890987654321123456")), + Literal(BigDecimal("1234567890987654321123458"))), + BigDecimal("1234567890987654321123458"), InternalRow.empty) + checkEvaluation( + Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))), + Date.valueOf("2015-07-01"), InternalRow.empty) + checkEvaluation( + Greatest( + Literal(Timestamp.valueOf("2015-07-01 08:00:00")), + Literal(Timestamp.valueOf("2015-07-01 10:00:00"))), + Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) } } From 7a6bdbb175b6bcfd4bae0b5a6ee2a071056a8e36 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 9 Jul 2015 20:17:23 -0700 Subject: [PATCH 3/6] add '.' for hex() --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f75bf1975aae..3239368c9a29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1092,7 +1092,7 @@ object functions { greatest((columnName +: columnNames).map(Column.apply): _*) /** - * Computes hex value of the given column + * Computes hex value of the given column. * * @group math_funcs * @since 1.5.0 @@ -1100,7 +1100,7 @@ object functions { def hex(column: Column): Column = Hex(column.expr) /** - * Computes hex value of the given input + * Computes hex value of the given input. * * @group math_funcs * @since 1.5.0 From 0f1bff249b4a8ecc25d21e0c4e0ea4f174d9b092 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 12 Jul 2015 20:08:00 -0700 Subject: [PATCH 4/6] address comments from davis --- .../catalyst/expressions/conditionals.scala | 16 +++++++------- .../org/apache/spark/sql/functions.scala | 22 +++++++++++++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index ccd453f94e78..e6a705fb8055 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -349,9 +349,9 @@ case class Least(children: Expression*) extends Expression { val evalChildren = children.map(_.gen(ctx)) def updateEval(i: Int): String = s""" - if (${ev.isNull} || (!${evalChildren(i).isNull} && ${ - ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { - ${ev.isNull} = ${evalChildren(i).isNull}; + if (!${evalChildren(i).isNull} && (${ev.isNull} || + ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { + ${ev.isNull} = false; ${ev.primitive} = ${evalChildren(i).primitive}; } """ @@ -359,7 +359,7 @@ case class Least(children: Expression*) extends Expression { ${evalChildren.map(_.code).mkString("\n")} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${(0 to children.length - 1).map(updateEval).mkString("\n")} + ${(0 until children.length).map(updateEval).mkString("\n")} """ } } @@ -399,9 +399,9 @@ case class Greatest(children: Expression*) extends Expression { val evalChildren = children.map(_.gen(ctx)) def updateEval(i: Int): String = s""" - if (${ev.isNull} || (!${evalChildren(i).isNull} && ${ - ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { - ${ev.isNull} = ${evalChildren(i).isNull}; + if (!${evalChildren(i).isNull} && (${ev.isNull} || + ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { + ${ev.isNull} = false; ${ev.primitive} = ${evalChildren(i).primitive}; } """ @@ -409,7 +409,7 @@ case class Greatest(children: Expression*) extends Expression { ${evalChildren.map(_.code).mkString("\n")} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${(0 to children.length - 1).map(updateEval).mkString("\n")} + ${(0 until children.length).map(updateEval).mkString("\n")} """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3239368c9a29..ffa52f62588d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1079,7 +1079,11 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = Greatest(exprs.map(_.expr): _*) + def greatest(exprs: Column*): Column = if (exprs.length < 2) { + sys.error("GREATEST takes at least 2 parameters") + } else { + Greatest(exprs.map(_.expr): _*) + } /** * Returns the greatest value of the list of column names. @@ -1088,8 +1092,11 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(columnName: String, columnNames: String*): Column = + def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { + sys.error("GREATEST takes at least 2 parameters") + } else { greatest((columnName +: columnNames).map(Column.apply): _*) + } /** * Computes hex value of the given column. @@ -1197,7 +1204,11 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = Least(exprs.map(_.expr): _*) + def least(exprs: Column*): Column = if (exprs.length < 2) { + sys.error("LEAST takes at least 2 parameters") + } else { + Least(exprs.map(_.expr): _*) + } /** * Returns the least value of the list of column names. @@ -1206,8 +1217,11 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(columnName: String, columnNames: String*): Column = + def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { + sys.error("LEAST takes at least 2 parameters") + } else { least((columnName +: columnNames).map(Column.apply): _*) + } /** * Computes the natural logarithm of the given value. From 86fb049829b8393892d6fc0b7b47e21879fe8b20 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 13 Jul 2015 16:03:10 +0800 Subject: [PATCH 5/6] use seq for case class --- .../apache/spark/sql/catalyst/expressions/conditionals.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index e6a705fb8055..f1697845551c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -314,7 +314,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } } -case class Least(children: Expression*) extends Expression { +case class Least(children: Seq[Expression]) extends Expression { require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) @@ -364,7 +364,7 @@ case class Least(children: Expression*) extends Expression { } } -case class Greatest(children: Expression*) extends Expression { +case class Greatest(children: Seq[Expression]) extends Expression { require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) From 22e8f3d786a6903c6e0996b8003ca38b6f67f3f9 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 13 Jul 2015 01:06:55 -0700 Subject: [PATCH 6/6] use seq for case class in functions --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ffa52f62588d..606b4c418722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1082,7 +1082,7 @@ object functions { def greatest(exprs: Column*): Column = if (exprs.length < 2) { sys.error("GREATEST takes at least 2 parameters") } else { - Greatest(exprs.map(_.expr): _*) + Greatest(exprs.map(_.expr)) } /** @@ -1207,7 +1207,7 @@ object functions { def least(exprs: Column*): Column = if (exprs.length < 2) { sys.error("LEAST takes at least 2 parameters") } else { - Least(exprs.map(_.expr): _*) + Least(exprs.map(_.expr)) } /**