-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23919][SQL] Add array_position function #21037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fb04539
3d8069d
5339a65
f1238b6
07305db
3a16231
d4cebed
9a0321d
7362b1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast | |
|
|
||
| override def prettyName: String = "array_max" | ||
| } | ||
|
|
||
|
|
||
/**
 * Returns the position of the first occurrence of an element in the given array, as a long.
 * Returns 0 if the given value could not be found in the array. Returns null if either of
 * the arguments is null.
 *
 * NOTE: positions are 1-based, not 0-based: the first element in the array has position 1.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element in
      the array as long, or 0 if the element is not found.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(3, 2, 1), 1);
       3
  """,
  since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  /** The position is always reported as a long. */
  override def dataType: DataType = LongType

  /**
   * Expects an array followed by a value of that array's element type.
   *
   * When the first argument is not an array we must not cast its type to ArrayType (that would
   * throw a ClassCastException from here). Instead we still demand an array in the first slot
   * and leave the second slot at its current type.
   * NOTE(review): presumably this then surfaces as an ordinary input-type mismatch on the first
   * argument -- confirm against the ExpectsInputTypes check.
   */
  override def inputTypes: Seq[AbstractDataType] = left.dataType match {
    case arrayType: ArrayType => Seq(ArrayType, arrayType.elementType)
    case _ => Seq(ArrayType, right.dataType)
  }

  /**
   * Interpreted evaluation; both arguments are non-null here.
   *
   * @param arr   the array to search, as ArrayData
   * @param value the element to look for
   * @return the 1-based position of the first element equal to `value`, or 0L if there is none
   */
  override def nullSafeEval(arr: Any, value: Any): Any = {
    val elements = arr.asInstanceOf[ArrayData]
    val elementType = right.dataType
    val size = elements.numElements()
    // Plain index loop (same shape as the generated code below) rather than a `return` from
    // inside a closure, which would be a nonlocal return.
    var position = 0L
    var i = 0
    while (position == 0L && i < size) {
      // Null elements never match: `value` is guaranteed non-null in nullSafeEval.
      // NOTE(review): elements are compared with `==` here, while the generated code uses
      // ctx.genEqual; for element types without value-based equality (presumably binary
      // data) the two paths may disagree -- confirm.
      if (!elements.isNullAt(i) && elements.get(i, elementType) == value) {
        position = i + 1L
      }
      i += 1
    }
    position
  }

  override def prettyName: String = "array_position"

  /**
   * Generates a Java loop that scans the array, skips null elements, stops at the first element
   * equal to the searched value and stores its 1-based position (0 if none) in `ev.value`.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val pos = ctx.freshName("arrayPosition")
      val i = ctx.freshName("i")
      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
      s"""
         |int $pos = 0;
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
         |    $pos = $i + 1;
         |    break;
         |  }
         |}
         |${ev.value} = (long) $pos;
       """.stripMargin
    })
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wanted to note that we can use
`note` here too: spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
Lines 101 to 103 in 2ce37b5
I am mentioning this because we are adding many functions now :-).