diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,30 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. + override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. + * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {