Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[SPARK-23934][SQL] Handling of null entries
  • Loading branch information
mn-mikke committed Jun 21, 2018
commit 599656eed53222d5e243db663bf52cc3c1e802a7
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,36 @@ class CodegenContext {
}
}

/**
* Generates code to do null safe execution when accessing properties of complex
* ArrayData elements.
*
* When `nullElements` is true, the returned code first loops over `arrayData` and ORs the
* result of `isNullAt` for each element into `isNull`; the loop stops as soon as `isNull`
* is set (and does not run at all if `isNull` is already set on entry). `execute` is then
* run only if `isNull` is still false. When `nullElements` is false, `execute` is returned
* as is, without any wrapping.
*
* NOTE(review): "Save" in the method name presumably means "Safe" - confirm before renaming,
* since callers refer to this name.
*
* @param nullElements used to decide whether the ArrayData might contain null or not; the
* null-check wrapper is only generated when this is true.
* @param isNull a variable indicating whether the result will be evaluated to null or not;
* the generated loop both reads and updates it.
* @param arrayData a variable name representing the ArrayData.
* @param execute the code that should be executed only if the ArrayData doesn't contain
* any null.
* @return `execute` guarded by the generated null check, or `execute` unchanged.
*/
def nullArrayElementsSaveExec(
nullElements: Boolean,
isNull: String,
arrayData: String)(
execute: String): String = {
// Name of the index variable used by the generated loop. It is requested before the
// branch, so freshName is invoked even when no loop is emitted.
val i = freshName("idx")
if (nullElements) {
// Scan the elements for null first, then guard `execute` with the resulting flag.
s"""
|for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) {
| $isNull |= $arrayData.isNullAt($i);
|}
|if (!$isNull) {
| $execute
|}
""".stripMargin
} else {
// No null check is generated; the caller's code is passed through untouched.
execute
}
}

/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {

// Third component of dataTypeDetails; it guards the isNullAt checks performed on the
// entries of the input array.
// NOTE(review): dataTypeDetails.get presumably only succeeds once checkInputDataTypes has
// passed - confirm.
private def nullEntries: Boolean = dataTypeDetails.get._3

// The result may be null when the child is nullable or when nullEntries is set.
override def nullable: Boolean = child.nullable || nullEntries

// The resulting map type is the first component of dataTypeDetails.
override def dataType: MapType = dataTypeDetails.get._1

override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
Expand All @@ -510,71 +512,60 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {

override protected def nullSafeEval(input: Any): Any = {
val arrayData = input.asInstanceOf[ArrayData]
val length = arrayData.numElements()
val numEntries = if (nullEntries) (0 until length).count(!arrayData.isNullAt(_)) else length
val numEntries = arrayData.numElements()
var i = 0
if(nullEntries) {
while (i < numEntries) {
if (arrayData.isNullAt(i)) return null
i += 1
}
}
val keyArray = new Array[AnyRef](numEntries)
val valueArray = new Array[AnyRef](numEntries)
var i = 0
var j = 0
while (i < length) {
if (!arrayData.isNullAt(i)) {
val entry = arrayData.getStruct(i, 2)
val key = entry.get(0, dataType.keyType)
if (key == null) {
throw new RuntimeException("The first field from a struct (key) can't be null.")
}
keyArray.update(j, key)
val value = entry.get(1, dataType.valueType)
valueArray.update(j, value)
j += 1
i = 0
while (i < numEntries) {
val entry = arrayData.getStruct(i, 2)
val key = entry.get(0, dataType.keyType)
if (key == null) {
throw new RuntimeException("The first field from a struct (key) can't be null.")
}
keyArray.update(i, key)
val value = entry.get(1, dataType.valueType)
valueArray.update(i, value)
i += 1
}
ArrayBasedMapData(keyArray, valueArray)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val length = ctx.freshName("length")
val numEntries = ctx.freshName("numEntries")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, c, ev.value, length, numEntries)
genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
} else {
genCodeForAnyElements(ctx, c, ev.value, length, numEntries)
genCodeForAnyElements(ctx, c, ev.value, numEntries)
}
val numEntriesAssignment = if (nullEntries) {
val idx = ctx.freshName("idx")
ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
s"""
|int $numEntries = 0;
|for (int $idx = 0; $idx < $length; $idx++) {
| if (!$c.isNullAt($idx)) $numEntries++;
|}
|final int $numEntries = $c.numElements();
|$code
""".stripMargin
} else {
s"final int $numEntries = $length;"
}

s"""
|final int $length = $c.numElements();
|$numEntriesAssignment
|$code
""".stripMargin
})
}

private def genCodeForAssignmentLoop(
ctx: CodegenContext,
childVariable: String,
length: String,
mapData: String,
numEntries: String,
keyAssignment: (String, String) => String,
valueAssignment: (String, String) => String): String = {
val entry = ctx.freshName("entry")
val i = ctx.freshName("idx")
val j = ctx.freshName("idx")

val nullEntryCheck = if (nullEntries) s"if ($childVariable.isNullAt($i)) continue;" else ""
val nullKeyCheck = if (dataTypeDetails.get._2) {
s"""
|if ($entry.isNullAt(0)) {
Expand All @@ -586,13 +577,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
}

s"""
|for (int $i = 0, $j = 0; $i < $length; $i++) {
| $nullEntryCheck
|for (int $i = 0; $i < $numEntries; $i++) {
| InternalRow $entry = $childVariable.getStruct($i, 2);
| $nullKeyCheck
| ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), j)}
| ${valueAssignment(entry, j)}
| $j++;
| ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)}
| ${valueAssignment(entry, i)}
|}
""".stripMargin
}
Expand All @@ -601,7 +590,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
ctx: CodegenContext,
childVariable: String,
mapData: String,
length: String,
numEntries: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val keySectionSize = ctx.freshName("keySectionSize")
Expand Down Expand Up @@ -638,7 +626,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
length,
mapData,
numEntries,
keyAssignment,
valueAssignment
)
Expand All @@ -648,7 +637,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
|final long $valueSectionSize = $vByteSize;
|final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, childVariable, mapData, length, numEntries)}
| ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeMapData $unsafeMapData = new UnsafeMapData();
Expand All @@ -668,7 +657,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
ctx: CodegenContext,
childVariable: String,
mapData: String,
length: String,
numEntries: String): String = {
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
Expand All @@ -687,7 +675,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
length,
mapData,
numEntries,
keyAssignment,
valueAssignment)

Expand Down Expand Up @@ -2218,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
} else {
genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
}
if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code)
})
}

/**
* Wraps `coreLogic` in generated code that skips it when the value named by
* `childVariableName` contains a null element.
*
* The generated loop ORs `isNullAt` of every element into `ev.isNull`, stopping as soon as
* `ev.isNull` is set; `coreLogic` is then executed only if `ev.isNull` is still false.
*
* NOTE(review): the loop index is the hard-coded name `z` rather than a generated name, so
* it could presumably clash with a `z` declared by surrounding generated code - confirm.
*
* @param ev supplies the `isNull` variable that the generated code reads and updates.
* @param childVariableName name of the variable whose elements are checked for null.
* @param coreLogic the code to run only when no null element was found.
* @return `coreLogic` guarded by the generated null check.
*/
private def nullElementsProtection(
ev: ExprCode,
childVariableName: String,
coreLogic: String): String = {
// Scan the elements for null first, then guard `coreLogic` with the resulting flag.
s"""
|for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
| ${ev.isNull} |= $childVariableName.isNullAt(z);
|}
|if (!${ev.isNull}) {
| $coreLogic
|}
""".stripMargin
}

private def genCodeForNumberOfElements(
ctx: CodegenContext,
childVariableName: String) : (String, String) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType)
val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType))
val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType))
val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType))

checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
Expand All @@ -111,10 +108,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai5),
"The first field from a struct (key) can't be null.")
checkEvaluation(MapFromEntries(ai6), Map(2 -> 20))
checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte))
checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort))
checkEvaluation(MapFromEntries(alo), Map(1L -> 10L))
checkEvaluation(MapFromEntries(ai6), null)

// Non-primitive-type keys and values
val asType = arrayType(StringType, StringType)
Expand All @@ -134,7 +128,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
"The first field from a struct (key) can't be null.")
checkEvaluation(MapFromEntries(as6), Map("b" -> "bb"))
checkEvaluation(MapFromEntries(as6), null)
}

test("Sort Array") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).toDF("a")
val iExpected = Seq(
Row(Map(1 -> 10, 2 -> 20, 3 -> 10)),
Row(Map(1 -> 10, 2 -> 20)),
Row(null),
Row(Map.empty),
Row(null))

Expand All @@ -673,7 +673,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).toDF("a")
val sExpected = Seq(
Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")),
Row(Map("a" -> "aa", "b" -> "bb")),
Row(null),
Row(Map("a" -> null, "b" -> null)),
Row(Map.empty),
Row(null))
Expand Down