Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove parameters containing element data types
Replace key/value input data with map input data
Remove null check for keys
Remove unneeded code for sequences
  • Loading branch information
michalsenkyr committed Jun 4, 2017
commit 25ec2f0ca09f63d214e932af29371ebd2f81f840
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

Expand Down Expand Up @@ -335,12 +335,8 @@ object ScalaReflection extends ScalaReflection {

CollectObjectsToMap(
p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
returnNullable = false),
schemaFor(keyType).dataType,
p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType,
getPath,
mirror.runtimeClass(t.typeSymbol.asClass)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,34 +660,28 @@ object CollectObjectsToMap {
* Construct an instance of CollectObjects case class.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CollectObjects -> CollectObjectsToMap

*
* @param keyFunction The function applied on the key collection elements.
* @param keyInputData An expression that when evaluated returns a key collection object.
* @param keyElementType The data type of key elements in the collection.
* @param valueFunction The function applied on the value collection elements.
* @param valueInputData An expression that when evaluated returns a value collection object.
* @param valueElementType The data type of value elements in the collection.
* @param inputData An expression that when evaluated returns a map object.
* @param collClass The type of the resulting collection.
*/
def apply(
keyFunction: Expression => Expression,
keyInputData: Expression,
keyElementType: DataType,
valueFunction: Expression => Expression,
valueInputData: Expression,
valueElementType: DataType,
inputData: Expression,
collClass: Class[_]): CollectObjectsToMap = {
val id = curId.getAndIncrement()
val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id"
val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull, keyElementType)
val mapType = inputData.dataType.asInstanceOf[MapType]
val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, valueElementType)
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id"
val builderValue = s"CollectObjectsToMap_builderValue$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generate name for keyLoopVar and valueLoopVar here because they are used in the keyFunction and valueFunction. The tupleLoopVar and builderValue don't have this problem and we can generate them in class CollectObjectsToMap

CollectObjectsToMap(
keyLoopValue, keyLoopIsNull, keyElementType, keyFunction(keyLoopVar), keyInputData,
valueLoopValue, valueLoopIsNull, valueElementType, valueFunction(valueLoopVar),
valueInputData,
keyLoopValue, keyFunction(keyLoopVar),
valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
inputData,
tupleLoopVar, collClass, builderValue)
}
}
Expand All @@ -699,123 +693,73 @@ object CollectObjectsToMap {
*
* @param keyLoopValue the name of the loop variable that is used when iterating over the key
* collection, and which is used as input for the `keyLambdaFunction`
* @param keyLoopIsNull the nullability of the loop variable that is used when iterating over
* the key collection, and which is used as input for the `keyLambdaFunction`
* @param keyLoopVarDataType the data type of the loop variable that is used when iterating over
* the key collection, and which is used as input for the
* `keyLambdaFunction`
* @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param keyInputData An expression that when evaluated returns a collection object.
* @param valueLoopValue the name of the loop variable that is used when iterating over the value
* collection, and which is used as input for the `valueLambdaFunction`
* @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
* the value collection, and which is used as input for the
* `valueLambdaFunction`
* @param valueLoopVarDataType the data type of the loop variable that is used when iterating over
* the value collection, and which is used as input for the
* `valueLambdaFunction`
* @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param valueInputData An expression that when evaluated returns a collection object.
* @param inputData An expression that when evaluated returns a map object.
* @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the
* resulting map (used only for Scala Map)
* @param collClass The type of the resulting collection.
* @param builderValue The name of the builder variable used to construct the resulting collection.
*/
case class CollectObjectsToMap private(
keyLoopValue: String,
keyLoopIsNull: String,
keyLoopVarDataType: DataType,
keyLambdaFunction: Expression,
keyInputData: Expression,
valueLoopValue: String,
valueLoopIsNull: String,
valueLoopVarDataType: DataType,
valueLambdaFunction: Expression,
valueInputData: Expression,
inputData: Expression,
tupleLoopValue: String,
collClass: Class[_],
builderValue: String) extends Expression with NonSQLExpression {

override def nullable: Boolean = keyInputData.nullable
override def nullable: Boolean = inputData.nullable

override def children: Seq[Expression] =
keyLambdaFunction :: keyInputData :: valueLambdaFunction :: valueInputData :: Nil
keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType = ObjectType(collClass)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val keyElementJavaType = ctx.javaType(keyLoopVarDataType)
ctx.addMutableState("boolean", keyLoopIsNull, "")
val mapType = inputData.dataType.asInstanceOf[MapType]
val keyElementJavaType = ctx.javaType(mapType.keyType)
ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
val genKeyInputData = keyInputData.genCode(ctx)
val genKeyFunction = keyLambdaFunction.genCode(ctx)
val valueElementJavaType = ctx.javaType(valueLoopVarDataType)
val valueElementJavaType = ctx.javaType(mapType.valueType)
ctx.addMutableState("boolean", valueLoopIsNull, "")
ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
val genValueInputData = valueInputData.genCode(ctx)
val genValueFunction = valueLambdaFunction.genCode(ctx)
val genInputData = inputData.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
val loopIndex = ctx.freshName("loopIndex")

// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
// of input collection at runtime for this case.
val keySeq = ctx.freshName("keySeq")
val keyArray = ctx.freshName("keyArray")
val valueSeq = ctx.freshName("valueSeq")
val valueArray = ctx.freshName("valueArray")
def determineCollectionType(inputData: Expression, genInputData: ExprCode,
elementJavaType: String, seq: String, array: String) =
inputData.dataType match {
case ObjectType(cls) if cls == classOf[Object] =>
val seqClass = classOf[Seq[_]].getName
s"""
$seqClass $seq = null;
$elementJavaType[] $array = null;
if (${genInputData.value}.getClass().isArray()) {
$array = ($elementJavaType[]) ${genInputData.value};
} else {
$seq = ($seqClass) ${genInputData.value};
}
"""
case _ => ""
}
val determineKeyCollectionType = determineCollectionType(
keyInputData, genKeyInputData, keyElementJavaType, keySeq, keyArray)
val determineValueCollectionType = determineCollectionType(
valueInputData, genValueInputData, valueElementJavaType, valueSeq, valueArray)

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
def inputDataType(inputData: Expression) = inputData.dataType match {
def inputDataType(dataType: DataType) = dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code in MapObejcts is:

    val inputDataType = inputData.dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => inputData.dataType
    }

We should call this before we do val mapType = inputData.dataType.asInstanceOf[MapType]

case p: PythonUserDefinedType => p.sqlType
case _ => inputData.dataType
case _ => dataType
}
val keyInputDataType = inputDataType(keyInputData)
val valueInputDataType = inputDataType(valueInputData)

def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode,
seq: String, array: String) =
inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
case ObjectType(cls) if cls.isArray =>
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
case ArrayType(et, _) =>
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
case ObjectType(cls) if cls == classOf[Object] =>
s"$seq == null ? $array.length : $seq.size()" ->
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}

def lengthAndLoopVar(elementType: DataType, genInputData: ExprCode, method: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's just 2 lines method, can we inline it?

array: String) =
s"${genInputData.value}.$method().numElements()" ->
ctx.getValue(s"${genInputData.value}.$method()", elementType, loopIndex)

val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = (
lengthAndLoopVar(inputDataType(keyInputData), genKeyInputData, keySeq, keyArray),
lengthAndLoopVar(inputDataType(valueInputData), genValueInputData, valueSeq, valueArray)
lengthAndLoopVar(inputDataType(mapType.keyType), genInputData, "keyArray", keyArray),
lengthAndLoopVar(inputDataType(mapType.valueType), genInputData, "valueArray", valueArray)
)

// Make a copy of the data if it's unsafe-backed
Expand All @@ -831,19 +775,8 @@ case class CollectObjectsToMap private(
val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

def loopNullCheck(genInputData: ExprCode, inputDataType: DataType,
loopIsNull: String, loopValue: String) =
inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
s"$loopIsNull = false"
case _ => s"$loopIsNull = $loopValue == null;"
}
val keyLoopNullCheck =
loopNullCheck(genKeyInputData, keyInputDataType, keyLoopIsNull, keyLoopValue)
val valueLoopNullCheck =
loopNullCheck(genValueInputData, valueInputDataType, valueLoopIsNull, valueLoopValue)
s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about $valueArray.isNullAt($loopIndex)?


val constructBuilder = collClass match {
// Scala Map
Expand Down Expand Up @@ -873,7 +806,6 @@ case class CollectObjectsToMap private(
s"${collClass.getName} $builderValue = new $builderClass();"
// Java concrete Map implementation
case cls =>
val builderClass = classOf[java.util.Map[_, _]].getName
// Check for constructor with capacity specification
if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) {
s"${collClass.getName} $builderValue = new ${cls.getName}($dataLength);"
Expand Down Expand Up @@ -902,18 +834,10 @@ case class CollectObjectsToMap private(
}

val code = s"""
${genKeyInputData.code}
${genValueInputData.code}
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if ((${genKeyInputData.isNull} && !${genValueInputData.isNull}) ||
(!${genKeyInputData.isNull} && ${genValueInputData.isNull})) {
throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
}

if (!${genKeyInputData.isNull}) {
$determineKeyCollectionType
$determineValueCollectionType
if (!${genInputData.isNull}) {
if ($getKeyLength != $getValueLength) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need a keyLength and valueLength, just have a mapLength which can be calculated by MapData.numElements

throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
}
Expand All @@ -924,16 +848,11 @@ case class CollectObjectsToMap private(
while ($loopIndex < $dataLength) {
$keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
$valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
$keyLoopNullCheck
$valueLoopNullCheck
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also inline this. The principle is, we should inline these simple codes as many as possible, then when you look at this code block, it's more clear what's going on.


${genKeyFunction.code}
${genValueFunction.code}

if (${genKeyFunction.isNull}) {
throw new RuntimeException("Found null in map key!");
}

$appendToBuilder

$loopIndex += 1;
Expand All @@ -942,7 +861,7 @@ case class CollectObjectsToMap private(
$getBuilderResult
}
"""
ev.copy(code = code, isNull = genKeyInputData.isNull)
ev.copy(code = code, isNull = genInputData.isNull)
}
}

Expand Down