-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23921][SQL] Add array_sort function #21021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
72f31b3
a9b6e3b
d57c14a
9977f64
172b2c5
f2798f9
175d981
d1b0483
04a3ae5
9f63a76
e3fcaaa
2ad6bb8
2c4404c
21521d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ import java.util.Comparator | |
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions.ArraySortUtil.NullOrder | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
| import org.apache.spark.sql.types._ | ||
|
|
@@ -117,47 +118,16 @@ case class MapValues(child: Expression) | |
| } | ||
|
|
||
| /** | ||
| * Sorts the input array in ascending / descending order according to the natural ordering of | ||
| * the array elements and returns it. | ||
| * Common base class for [[SortArray]] and [[ArraySort]]. | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); | ||
| ["a","b","c","d"] | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class SortArray(base: Expression, ascendingOrder: Expression) | ||
| extends BinaryExpression with ExpectsInputTypes with CodegenFallback { | ||
|
|
||
| def this(e: Expression) = this(e, Literal(true)) | ||
| trait ArraySortUtil extends ExpectsInputTypes { | ||
| protected def arrayExpression: Expression | ||
|
|
||
| override def left: Expression = base | ||
| override def right: Expression = ascendingOrder | ||
| override def dataType: DataType = base.dataType | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = base.dataType match { | ||
| case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => | ||
| ascendingOrder match { | ||
| case Literal(_: Boolean, BooleanType) => | ||
| TypeCheckResult.TypeCheckSuccess | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure( | ||
| "Sort order in second argument requires a boolean literal.") | ||
| } | ||
| case ArrayType(dt, _) => | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"$prettyName does not support sorting array of type ${dt.simpleString}") | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") | ||
| } | ||
| protected def nullOrder: NullOrder | ||
|
|
||
| @transient | ||
| private lazy val lt: Comparator[Any] = { | ||
| val ordering = base.dataType match { | ||
| val ordering = arrayExpression.dataType match { | ||
| case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
|
|
@@ -168,9 +138,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) | |
| if (o1 == null && o2 == null) { | ||
| 0 | ||
| } else if (o1 == null) { | ||
| -1 | ||
| nullOrder | ||
| } else if (o2 == null) { | ||
| 1 | ||
| -nullOrder | ||
| } else { | ||
| ordering.compare(o1, o2) | ||
| } | ||
|
|
@@ -180,7 +150,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) | |
|
|
||
| @transient | ||
| private lazy val gt: Comparator[Any] = { | ||
| val ordering = base.dataType match { | ||
| val ordering = arrayExpression.dataType match { | ||
| case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
|
|
@@ -191,28 +161,205 @@ case class SortArray(base: Expression, ascendingOrder: Expression) | |
| if (o1 == null && o2 == null) { | ||
| 0 | ||
| } else if (o1 == null) { | ||
| 1 | ||
| -nullOrder | ||
| } else if (o2 == null) { | ||
| -1 | ||
| nullOrder | ||
| } else { | ||
| -ordering.compare(o1, o2) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def nullSafeEval(array: Any, ascending: Any): Any = { | ||
| val elementType = base.dataType.asInstanceOf[ArrayType].elementType | ||
| def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| def sortEval(array: Any, ascending: Boolean): Any = { | ||
| val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) | ||
| if (elementType != NullType) { | ||
| java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) | ||
| java.util.Arrays.sort(data, if (ascending) lt else gt) | ||
| } | ||
| new GenericArrayData(data.asInstanceOf[Array[Any]]) | ||
| } | ||
|
|
||
| def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { | ||
| val arrayData = classOf[ArrayData].getName | ||
| val genericArrayData = classOf[GenericArrayData].getName | ||
| val array = ctx.freshName("array") | ||
| val c = ctx.freshName("c") | ||
| val sort = if (elementType == NullType) "" else { | ||
|
||
| val sortOrder = ctx.freshName("sortOrder") | ||
| val o1 = ctx.freshName("o1") | ||
| val o2 = ctx.freshName("o2") | ||
| val jt = CodeGenerator.javaType(elementType) | ||
| val comp = if (CodeGenerator.isPrimitiveType(elementType)) { | ||
| val bt = CodeGenerator.boxedType(elementType) | ||
| val v1 = ctx.freshName("v1") | ||
| val v2 = ctx.freshName("v2") | ||
| s""" | ||
| |$jt $v1 = (($bt) $o1).${jt}Value(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to enforce the boxing? And why do we need to cast to the Java type in the non-primitive scenario?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC, this is because
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now I see, thanks. |
||
| |$jt $v2 = (($bt) $o2).${jt}Value(); | ||
| |int $c = ${ctx.genComp(elementType, v1, v2)}; | ||
| """.stripMargin | ||
| } else { | ||
| s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" | ||
| } | ||
| s""" | ||
| |final int $sortOrder = $order ? 1 : -1; | ||
| |java.util.Arrays.sort($array, new java.util.Comparator() { | ||
| | @Override public int compare(Object $o1, Object $o2) { | ||
| | if ($o1 == null && $o2 == null) { | ||
| | return 0; | ||
| | } else if ($o1 == null) { | ||
| | return $sortOrder * $nullOrder; | ||
| | } else if ($o2 == null) { | ||
| | return -$sortOrder * $nullOrder; | ||
| | } | ||
| | $comp | ||
| | return $sortOrder * $c; | ||
| | } | ||
| |}); | ||
| """.stripMargin | ||
| } | ||
| val dataTypes = elementType match { | ||
| case DecimalType.Fixed(p, s) => | ||
| s"org.apache.spark.sql.types.DataTypes.createDecimalType($p, $s)" | ||
| case ArrayType(et, cn) => | ||
| s"org.apache.spark.sql.types.DataTypes.createArrayType($et, $cn)" | ||
| case MapType(kt, vt, cn) => | ||
|
||
| s"org.apache.spark.sql.types.DataTypes.createMapType($kt, $vt, $cn)" | ||
| case StructType(f) => | ||
| "org.apache.spark.sql.types.StructType$.MODULE$." + | ||
| s"apply(new java.util.ArrayList(${f.length}))" | ||
| case _ => | ||
| s"org.apache.spark.sql.types.DataTypes.$elementType" | ||
| } | ||
|
||
| s""" | ||
| |Object[] $array = (Object[]) (($arrayData) $base).toArray( | ||
|
||
| | $dataTypes, scala.reflect.ClassTag$$.MODULE$$.AnyRef()); | ||
| |$sort | ||
| |${ev.value} = new $genericArrayData($array); | ||
| """.stripMargin | ||
| } | ||
|
|
||
| } | ||
|
|
||
| object ArraySortUtil { | ||
| type NullOrder = Int | ||
| // Least: null compares lower than any non-null element, so nulls come first in ascending order | ||
| // Greatest: null compares higher than any non-null element, so nulls come last in ascending order | ||
| object NullOrder { | ||
| val Least: NullOrder = -1 | ||
| val Greatest: NullOrder = 1 | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Sorts the input array in ascending or descending order according to the natural ordering of | ||
| * the array elements and returns the sorted array. | ||
| * | ||
| * The sort direction (`ascendingOrder`) must be a boolean literal; the one-argument constructor | ||
| * defaults it to `true`. Null elements use [[NullOrder.Least]], so they are placed at the | ||
| * beginning of the result in ascending order and at the end in descending order. | ||
| * | ||
| * Type checking rejects non-array input and arrays whose element type is not orderable. | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = """ | ||
| _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order | ||
| according to the natural ordering of the array elements. Null elements will be placed | ||
| at the beginning of the returned array in ascending order or at the end of the returned | ||
| array in descending order. | ||
| """, | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); | ||
| [null,"a","b","c","d"] | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class SortArray(base: Expression, ascendingOrder: Expression) | ||
| extends BinaryExpression with ArraySortUtil { | ||
|
|
||
| def this(e: Expression) = this(e, Literal(true)) | ||
|
|
||
| override def left: Expression = base | ||
| override def right: Expression = ascendingOrder | ||
| override def dataType: DataType = base.dataType | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) | ||
|
|
||
| override def arrayExpression: Expression = base | ||
| override def nullOrder: NullOrder = NullOrder.Least | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = base.dataType match { | ||
| case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => | ||
| ascendingOrder match { | ||
| case Literal(_: Boolean, BooleanType) => | ||
| TypeCheckResult.TypeCheckSuccess | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure( | ||
| "Sort order in second argument requires a boolean literal.") | ||
| } | ||
| case ArrayType(dt, _) => | ||
| val dtSimple = dt.simpleString | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"$prettyName does not support sorting array of type $dtSimple which is not orderable") | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") | ||
| } | ||
|
|
||
| override def nullSafeEval(array: Any, ascending: Any): Any = { | ||
| sortEval(array, ascending.asInstanceOf[Boolean]) | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) | ||
| } | ||
|
|
||
| override def prettyName: String = "sort_array" | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Sorts the input array in ascending order according to the natural ordering of | ||
| * the array elements and returns the sorted array. | ||
| * | ||
| * Unlike [[SortArray]], this expression takes no sort-direction argument and uses | ||
| * [[NullOrder.Greatest]], so null elements are placed at the end of the result. | ||
| * | ||
| * Type checking rejects non-array input and arrays whose element type is not orderable. | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = """ | ||
| _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must | ||
| be orderable. Null elements will be placed at the end of the returned array. | ||
| """, | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); | ||
| ["a","b","c","d",null] | ||
| """, | ||
| since = "2.4.0") | ||
| // scalastyle:on line.size.limit | ||
| case class ArraySort(child: Expression) extends UnaryExpression with ArraySortUtil { | ||
|
|
||
| override def dataType: DataType = child.dataType | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) | ||
|
|
||
| override def arrayExpression: Expression = child | ||
| override def nullOrder: NullOrder = NullOrder.Greatest | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = child.dataType match { | ||
| case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => | ||
| TypeCheckResult.TypeCheckSuccess | ||
| case ArrayType(dt, _) => | ||
| val dtSimple = dt.simpleString | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"$prettyName does not support sorting array of type $dtSimple which is not orderable") | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") | ||
| } | ||
|
|
||
| override def nullSafeEval(array: Any): Any = { | ||
| sortEval(array, true) | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) | ||
| } | ||
|
|
||
| override def prettyName: String = "array_sort" | ||
| } | ||
|
|
||
| /** | ||
| * Returns a reversed string or an array with reverse order of elements. | ||
| */ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about `ArraySortLike`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Thank you for reviewing this even though it is a long holiday week in Japan.