Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'master' of github.com:apache/spark into SPARK-23916
  • Loading branch information
mgaido91 committed Apr 17, 2018
commit 703c09c4e9da2b96c7a5f445fd5a1d30cdc29c03
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,133 @@ case class ArrayJoin(
override def dataType: DataType = StringType

}

/**
* Returns the minimum value in the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
1
""", since = "2.4.0")
case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
if (typeCheckResult.isSuccess) {
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
} else {
typeCheckResult
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
s"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${childGen.isNull}) {
| for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
| ${ctx.reassignIfSmaller(dataType, ev, item)}
| }
|}
""".stripMargin)
}

override protected def nullSafeEval(input: Any): Any = {
var min: Any = null
input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
if (item != null && (min == null || ordering.lt(item, min))) {
min = item
}
)
min
}

override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}

override def prettyName: String = "array_min"
}

/**
* Returns the maximum value in the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
20
""", since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
if (typeCheckResult.isSuccess) {
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
} else {
typeCheckResult
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
s"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${childGen.isNull}) {
| for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
| ${ctx.reassignIfGreater(dataType, ev, item)}
| }
|}
""".stripMargin)
}

override protected def nullSafeEval(input: Any): Any = {
var max: Any = null
input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
if (item != null && (max == null || ordering.gt(item, max))) {
max = item
}
)
max
}

override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}

override def prettyName: String = "array_max"
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal(","),
Some(Literal.create(null, StringType))), null)
}

test("Array Min") {
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
checkEvaluation(
ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "")
checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null)
checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null)
checkEvaluation(
ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234)
}

test("Array max") {
checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
checkEvaluation(
ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row("a,b"), Row("a,NULL,b"), Row("")))
}

test("array_min function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
Seq.empty[Option[Int]],
Seq[Option[Int]](None),
Seq[Option[Int]](None, Some(1), Some(-100))
).toDF("a")

val answer = Seq(Row(1), Row(null), Row(null), Row(-100))

checkAnswer(df.select(array_min(df("a"))), answer)
checkAnswer(df.selectExpr("array_min(a)"), answer)
}

test("array_max function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
Seq.empty[Option[Int]],
Seq[Option[Int]](None),
Seq[Option[Int]](None, Some(1), Some(-100))
).toDF("a")

val answer = Seq(Row(3), Row(null), Row(null), Row(1))

checkAnswer(df.select(array_max(df("a"))), answer)
checkAnswer(df.selectExpr("array_max(a)"), answer)
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.