Skip to content
Closed
Prev Previous commit
Next Next commit
Merge branch 'master' into SPARK-23939
  • Loading branch information
codeatri authored Aug 14, 2018
commit bb52630dd720ecaf5f7ffe0c498d422ce60f3bb7
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ object FunctionRegistry {
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,195 @@ case class TransformKeys(
}

override def prettyName: String = "transform_keys"
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent


/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
*/
@ExpressionDescription(
usage =
"""
_FUNC_(map1, map2, function) - Merges two given maps into a single map by applying
function to the pair of values with the same key. For keys only presented in one map,
NULL will be passed as the value for the missing key. If an input map contains duplicated
keys, only the first entry of the duplicated key is passed into the lambda function.
""",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2));
{1:"ax",2:"by"}
""",
since = "2.4.0")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {

def functionForEval: Expression = functionsForEval.head

@transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType

@transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType

@transient lazy val keyType =
TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get

@transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)

override def arguments: Seq[Expression] = left :: right :: Nil

override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil

override def functions: Seq[Expression] = function :: Nil

override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil

override def nullable: Boolean = left.nullable || right.nullable

override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = {
val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true))
copy(function = f(function, arguments))
}

override def checkArgumentDataTypes(): TypeCheckResult = {
super.checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (leftKeyType.sameType(rightKeyType)) {
TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName")
} else {
TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
}
case failure => failure
}
}

override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes()

override def eval(input: InternalRow): Any = {
val value1 = left.eval(input)
if (value1 == null) {
null
} else {
val value2 = right.eval(input)
if (value2 == null) {
null
} else {
nullSafeEval(input, value1, value2)
}
}
}

@transient lazy val LambdaFunction(_, Seq(
keyVar: NamedLambdaVariable,
value1Var: NamedLambdaVariable,
value2Var: NamedLambdaVariable),
_) = function

private def keyTypeSupportsEquals = keyType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

/**
* The function accepts two key arrays and returns a collection of keys with indexes
* to value arrays. Indexes are represented as an array of two items. This is a small
* optimization leveraging mutability of arrays.
*/
@transient private lazy val getKeysWithValueIndexes:
(ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
if (keyTypeSupportsEquals) {
getKeysWithIndexesFast
} else {
getKeysWithIndexesBruteForce
}
}

private def assertSizeOfArrayBuffer(size: Int): Unit = {
if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to zip maps with $size " +
s"unique keys due to exceeding the array size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
}

private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
hashMap.get(key) match {
case Some(indexes) =>
if (indexes(z).isEmpty) {
indexes(z) = Some(i)
}
case None =>
val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
hashMap.put(key, indexes)
}
i += 1
}
}
hashMap
}

private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val (bufferKey, indexes) = arrayBuffer(j)
if (ordering.equiv(bufferKey, key)) {
found = true
if(indexes(z).isEmpty) {
indexes(z) = Some(i)
}
}
j += 1
}
if (!found) {
assertSizeOfArrayBuffer(arrayBuffer.size)
val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
arrayBuffer += Tuple2(key, indexes)
}
i += 1
}
}
arrayBuffer
}

private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = {
val mapData1 = value1.asInstanceOf[MapData]
val mapData2 = value2.asInstanceOf[MapData]
val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray())
val size = keysWithIndexes.size
val keys = new GenericArrayData(new Array[Any](size))
val values = new GenericArrayData(new Array[Any](size))
val valueData1 = mapData1.valueArray()
val valueData2 = mapData2.valueArray()
var i = 0
for ((key, Array(index1, index2)) <- keysWithIndexes) {
val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null)
val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null)
keyVar.value.set(key)
value1Var.value.set(v1)
value2Var.value.set(v2)
keys.update(i, key)
values.update(i, functionForEval.eval(inputRow))
i += 1
}
new ArrayBasedMapData(keys, values)
}

override def prettyName: String = "map_zip_with"
}
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,118 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
right: Expression,
f: (Expression, Expression, Expression) => Expression): Expression = {
val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType]
val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType]
MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f))
}

val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii4 = MapFromArrays(
Literal.create(Seq(2, 2), ArrayType(IntegerType, false)),
Literal.create(Seq(20, 200), ArrayType(IntegerType, false)))
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = {
(k, v1, v2) => k * v1 * v2
}

checkEvaluation(
map_zip_with(mii0, mii1, multiplyKeyWithValues),
Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null))
checkEvaluation(
map_zip_with(mii0, mii2, multiplyKeyWithValues),
Map(1 -> null, 2 -> -80, 3 -> null))
checkEvaluation(
map_zip_with(mii0, mii3, multiplyKeyWithValues),
Map(1 -> null, 2 -> null, 3 -> null))
checkEvaluation(
map_zip_with(mii0, mii4, multiplyKeyWithValues),
Map(1 -> null, 2 -> 800, 3 -> null))
checkEvaluation(
map_zip_with(mii4, mii0, multiplyKeyWithValues),
Map(2 -> 800, 1 -> null, 3 -> null))
checkEvaluation(
map_zip_with(mii0, miin, multiplyKeyWithValues),
null)

val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"),
MapType(StringType, StringType, valueContainsNull = false))
val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"),
MapType(StringType, StringType, valueContainsNull = false))
val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
val mss4 = MapFromArrays(
Literal.create(Seq("a", "a"), ArrayType(StringType, false)),
Literal.create(Seq("a", "n"), ArrayType(StringType, false)))
val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false))

val concat: (Expression, Expression, Expression) => Expression = {
(k, v1, v2) => Concat(Seq(k, v1, v2))
}

checkEvaluation(
map_zip_with(mss0, mss1, concat),
Map("a" -> null, "b" -> "byd", "d" -> "dzb"))
checkEvaluation(
map_zip_with(mss1, mss2, concat),
Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null))
checkEvaluation(
map_zip_with(mss0, mss3, concat),
Map("a" -> null, "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss0, mss4, concat),
Map("a" -> "axa", "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss4, mss0, concat),
Map("a" -> "aax", "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss0, mssn, concat),
null)

def b(data: Byte*): Array[Byte] = Array[Byte](data: _*)

val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)),
MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)),
MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null),
MapType(BinaryType, BinaryType, valueContainsNull = true))
val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb4 = MapFromArrays(
Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)),
Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false)))
val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false))

checkEvaluation(
map_zip_with(mbb0, mbb1, concat),
Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null))
checkEvaluation(
map_zip_with(mbb1, mbb2, concat),
Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbb3, concat),
Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbb4, concat),
Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb4, mbb0, concat),
Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbbn, concat),
null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as
-- Aggregate a null array
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

-- Check for element existence
select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

  (1, map(1, 1, 2, 2, 3, 3)),
  (2, map(4, 4, 5, 5, 6, 6))

Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.