Skip to content
Prev Previous commit
Next Next commit
fix pyspark test failure
  • Loading branch information
kiszk committed Jun 8, 2018
commit 4eee89da3a80f679b8ff0a631c4374ae6fa0de86
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,8 @@ case class CreateMapFromArrays(left: Expression, right: Expression)

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (ArrayType(_, cn), ArrayType(_, _)) =>
if (!cn) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure("All of the given keys should be non-null")
}
case (ArrayType(_, _), ArrayType(_, _)) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure("The given two arguments should be an array")
}
Expand All @@ -281,18 +277,39 @@ case class CreateMapFromArrays(left: Expression, right: Expression)
if (keyArrayData.numElements != valueArrayData.numElements) {
throw new RuntimeException("The given two arrays should have the same length")
}
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
if (leftArrayType.containsNull) {
if (keyArrayData.toArray(leftArrayType.elementType).contains(null)) {
throw new RuntimeException("Cannot use null as map key!")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use loop to null-check without converting to object array?

}
new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
val arrayBasedMapData = classOf[ArrayBasedMapData].getName
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
val leftArrayTypeTerm = ctx.addReferenceObj("leftArrayType", leftArrayType.elementType)
val array = ctx.freshName("array")
val i = ctx.freshName("i")
s"""
|Object[] $array = $keyArrayData.toObjectArray($leftArrayTypeTerm);
|for (int $i = 0; $i < $array.length; $i++) {
| if ($array[$i] == null) {
| throw new RuntimeException("Cannot use null as map key!");
| }
|}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can null-check without converting to object array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I realized we have to evaluate each element as CreateMap does. I think that we have to update eval and codegen.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, but I couldn't get it. I might miss something, but I thought we can simply do like:

for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
  if ($keyArrayData.isNullAt($i)) {
    throw new RuntimeException("Cannot use null as map key!");
  }
}

Doesn't this work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code should work if we evaluate each element to make isNullAt() valid.

I think that my mistake is not to currently evaluate each element in keyArrayData and valueArrayData.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. An array has been evaluated.

""".stripMargin
}
s"""
|if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
| throw new RuntimeException("The given two arrays should have the same length");
|}
|$keyArrayElemNullCheck
|${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
"""
""".stripMargin
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,23 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
val intDupSeq = Seq(5, 10, 15, 15, 5)
val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))

val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
val longArray = Literal.create(longSeq, ArrayType(LongType, false))
val strArray = Literal.create(strSeq, ArrayType(StringType, false))

val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intWithNullArray?

val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

longWithNullArray?


val nullArray = Literal.create(null, ArrayType(StringType, false))

checkEvaluation(CreateMapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
checkEvaluation(CreateMapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
checkEvaluation(CreateMapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))

checkEvaluation(
CreateMapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
Expand All @@ -219,6 +221,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(CreateMapFromArrays(nullArray, nullArray), null)

intercept[RuntimeException] {
checkEvaluation(CreateMapFromArrays(intwithNullArray, strArray), null)
}
intercept[RuntimeException] {
checkEvaluation(
CreateMapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

val df3 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
intercept[AnalysisException] {
df3.select(map_from_arrays($"k", $"v"))
}
val df3 = Seq((null, null)).toDF("k", "v")
checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

val df4 = Seq((1, "a")).toDF("k", "v")
intercept[AnalysisException] {
df4.select(map_from_arrays($"k", $"v"))
}

val df5 = Seq((null, null)).toDF("k", "v")
intercept[AnalysisException] {
df5.select(map_from_arrays($"k", $"v"))
val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
intercept[RuntimeException] {
df5.select(map_from_arrays($"k", $"v")).collect
}

val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
Expand Down