-
Notifications
You must be signed in to change notification settings
Fork 29k
[SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions. #21352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions. #21352
Changes from 10 commits
ded67f5
e96962e
f6368b5
2862d3e
a4d1e7f
62c55ad
294ac69
94b86a2
a0abc25
872ef99
fd3a945
922d2f0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,27 +168,23 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI | |
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) | ||
|
|
||
| override def dataType: DataType = ArrayType(mountSchema) | ||
|
|
||
| override def nullable: Boolean = children.exists(_.nullable) | ||
|
|
||
| private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) | ||
|
|
||
| private lazy val arrayElementTypes = arrayTypes.map(_.elementType) | ||
|
|
||
| @transient private lazy val mountSchema: StructType = { | ||
| @transient override lazy val dataType: DataType = { | ||
| val fields = children.zip(arrayElementTypes).zipWithIndex.map { | ||
| case ((expr: NamedExpression, elementType), _) => | ||
| StructField(expr.name, elementType, nullable = true) | ||
| case ((_, elementType), idx) => | ||
| StructField(idx.toString, elementType, nullable = true) | ||
| } | ||
| StructType(fields) | ||
| ArrayType(StructType(fields), containsNull = false) | ||
| } | ||
|
|
||
| @transient lazy val numberOfArrays: Int = children.length | ||
| override def nullable: Boolean = children.exists(_.nullable) | ||
|
|
||
| @transient private lazy val arrayElementTypes = { | ||
| children.map(_.dataType.asInstanceOf[ArrayType].elementType) | ||
| } | ||
|
|
||
| @transient lazy val genericArrayData = classOf[GenericArrayData].getName | ||
| private def genericArrayData = classOf[GenericArrayData].getName | ||
|
|
||
| def emptyInputGenCode(ev: ExprCode): ExprCode = { | ||
| ev.copy(code""" | ||
|
|
@@ -256,7 +252,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI | |
| ("ArrayData[]", arrVals) :: Nil) | ||
|
|
||
| val initVariables = s""" | ||
| |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; | ||
| |ArrayData[] $arrVals = new ArrayData[${children.length}]; | ||
| |int $biggestCardinality = 0; | ||
| |${CodeGenerator.javaType(dataType)} ${ev.value} = null; | ||
| """.stripMargin | ||
|
|
@@ -268,7 +264,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI | |
| |if (!${ev.isNull}) { | ||
| | Object[] $args = new Object[$biggestCardinality]; | ||
| | for (int $i = 0; $i < $biggestCardinality; $i ++) { | ||
| | Object[] $currentRow = new Object[$numberOfArrays]; | ||
| | Object[] $currentRow = new Object[${children.length}]; | ||
| | $getValueForTypeSplitted | ||
| | $args[$i] = new $genericInternalRow($currentRow); | ||
| | } | ||
|
|
@@ -278,7 +274,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI | |
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| if (numberOfArrays == 0) { | ||
| if (children.length == 0) { | ||
| emptyInputGenCode(ev) | ||
| } else { | ||
| nonEmptyInputGenCode(ctx, ev) | ||
|
|
@@ -360,7 +356,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp | |
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(MapType) | ||
|
|
||
| lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] | ||
| private def childDataType: MapType = child.dataType.asInstanceOf[MapType] | ||
|
||
|
|
||
| override def dataType: DataType = { | ||
| ArrayType( | ||
|
|
@@ -520,7 +516,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres | |
| } | ||
| } | ||
|
|
||
| override def dataType: MapType = { | ||
| @transient override lazy val dataType: MapType = { | ||
| if (children.isEmpty) { | ||
| MapType(StringType, StringType) | ||
| } else { | ||
|
|
@@ -737,21 +733,22 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres | |
| since = "2.4.0") | ||
| case class MapFromEntries(child: Expression) extends UnaryExpression { | ||
|
|
||
| @transient | ||
| private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { | ||
| case ArrayType( | ||
| StructType(Array( | ||
| StructField(_, keyType, keyNullable, _), | ||
| StructField(_, valueType, valueNullable, _))), | ||
| containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) | ||
| case _ => None | ||
| @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = { | ||
|
||
| child.dataType match { | ||
| case ArrayType( | ||
| StructType(Array( | ||
| StructField(_, kt, kn, _), | ||
|
||
| StructField(_, vt, vn, _))), | ||
| cn) => Some((MapType(kt, vt, vn), kn, cn)) | ||
| case _ => None | ||
| } | ||
| } | ||
|
|
||
| private def nullEntries: Boolean = dataTypeDetails.get._3 | ||
| @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 | ||
|
|
||
| override def nullable: Boolean = child.nullable || nullEntries | ||
|
|
||
| override def dataType: MapType = dataTypeDetails.get._1 | ||
| @transient override lazy val dataType: MapType = dataTypeDetails.get._1 | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { | ||
| case Some(_) => TypeCheckResult.TypeCheckSuccess | ||
|
|
@@ -949,8 +946,7 @@ trait ArraySortLike extends ExpectsInputTypes { | |
|
|
||
| protected def nullOrder: NullOrder | ||
|
|
||
| @transient | ||
| private lazy val lt: Comparator[Any] = { | ||
| @transient private lazy val lt: Comparator[Any] = { | ||
| val ordering = arrayExpression.dataType match { | ||
| case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
|
|
@@ -972,8 +968,7 @@ trait ArraySortLike extends ExpectsInputTypes { | |
| } | ||
| } | ||
|
|
||
| @transient | ||
| private lazy val gt: Comparator[Any] = { | ||
| @transient private lazy val gt: Comparator[Any] = { | ||
| val ordering = arrayExpression.dataType match { | ||
| case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] | ||
| case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] | ||
|
|
@@ -995,7 +990,10 @@ trait ArraySortLike extends ExpectsInputTypes { | |
| } | ||
| } | ||
|
|
||
| def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType | ||
| @transient lazy val elementType: DataType = { | ||
| arrayExpression.dataType.asInstanceOf[ArrayType].elementType | ||
| } | ||
|
|
||
| def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull | ||
|
|
||
| def sortEval(array: Any, ascending: Boolean): Any = { | ||
|
|
@@ -1211,7 +1209,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI | |
|
|
||
| override def dataType: DataType = child.dataType | ||
|
|
||
| lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
| @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def nullSafeEval(input: Any): Any = input match { | ||
| case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) | ||
|
|
@@ -1603,7 +1601,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) | |
|
|
||
| override def children: Seq[Expression] = Seq(x, start, length) | ||
|
|
||
| lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType | ||
| @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { | ||
| val startInt = startVal.asInstanceOf[Int] | ||
|
|
@@ -1889,7 +1887,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast | |
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) | ||
|
|
||
| private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) | ||
| @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val typeCheckResult = super.checkInputDataTypes() | ||
|
|
@@ -1930,7 +1928,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast | |
| min | ||
| } | ||
|
|
||
| override def dataType: DataType = child.dataType match { | ||
| @transient override lazy val dataType: DataType = child.dataType match { | ||
| case ArrayType(dt, _) => dt | ||
| case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") | ||
| } | ||
|
|
@@ -1954,7 +1952,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast | |
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) | ||
|
|
||
| private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) | ||
| @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val typeCheckResult = super.checkInputDataTypes() | ||
|
|
@@ -1995,7 +1993,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast | |
| max | ||
| } | ||
|
|
||
| override def dataType: DataType = child.dataType match { | ||
| @transient override lazy val dataType: DataType = child.dataType match { | ||
| case ArrayType(dt, _) => dt | ||
| case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") | ||
| } | ||
|
|
@@ -2100,7 +2098,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) | ||
|
|
||
| override def dataType: DataType = left.dataType match { | ||
| @transient override lazy val dataType: DataType = left.dataType match { | ||
| case ArrayType(elementType, _) => elementType | ||
| case MapType(_, valueType, _) => valueType | ||
| } | ||
|
|
@@ -2209,9 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| """) | ||
| case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { | ||
|
|
||
| private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH | ||
|
|
||
| val allowedTypes = Seq(StringType, BinaryType, ArrayType) | ||
| private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| if (children.isEmpty) { | ||
|
|
@@ -2228,15 +2224,15 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio | |
| } | ||
| } | ||
|
|
||
| override def dataType: DataType = { | ||
| @transient override lazy val dataType: DataType = { | ||
| if (children.isEmpty) { | ||
| StringType | ||
| } else { | ||
| super.dataType | ||
| } | ||
| } | ||
|
|
||
| lazy val javaType: String = CodeGenerator.javaType(dataType) | ||
| private def javaType: String = CodeGenerator.javaType(dataType) | ||
|
|
||
| override def nullable: Boolean = children.exists(_.nullable) | ||
|
|
||
|
|
@@ -2256,9 +2252,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio | |
| } else { | ||
| val arrayData = inputs.map(_.asInstanceOf[ArrayData]) | ||
| val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) | ||
| if (numberOfElements > MAX_ARRAY_LENGTH) { | ||
| if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { | ||
| throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + | ||
| s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") | ||
| " elements due to exceeding the array size limit " + | ||
| ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") | ||
| } | ||
| val finalData = new Array[AnyRef](numberOfElements.toInt) | ||
| var position = 0 | ||
|
|
@@ -2316,9 +2313,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio | |
| |for (int z = 0; z < ${children.length}; z++) { | ||
| | $numElements += args[z].numElements(); | ||
| |} | ||
| |if ($numElements > $MAX_ARRAY_LENGTH) { | ||
| |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
| | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + | ||
| | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); | ||
| | " elements due to exceeding the array size limit" + | ||
| | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | ||
| |} | ||
| """.stripMargin | ||
|
|
||
|
|
@@ -2413,15 +2411,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio | |
| since = "2.4.0") | ||
| case class Flatten(child: Expression) extends UnaryExpression { | ||
|
|
||
| private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH | ||
|
|
||
| private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] | ||
| private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] | ||
|
|
||
| override def nullable: Boolean = child.nullable || childDataType.containsNull | ||
|
|
||
| override def dataType: DataType = childDataType.elementType | ||
| @transient override lazy val dataType: DataType = childDataType.elementType | ||
|
|
||
| lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
| @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = child.dataType match { | ||
| case ArrayType(_: ArrayType, _) => | ||
|
|
@@ -2441,9 +2437,10 @@ case class Flatten(child: Expression) extends UnaryExpression { | |
| } else { | ||
| val arrayData = elements.map(_.asInstanceOf[ArrayData]) | ||
| val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) | ||
| if (numberOfElements > MAX_ARRAY_LENGTH) { | ||
| if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { | ||
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + | ||
| s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") | ||
| s"$numberOfElements elements due to exceeding the array size limit " + | ||
| ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") | ||
| } | ||
| val flattenedData = new Array(numberOfElements.toInt) | ||
| var position = 0 | ||
|
|
@@ -2476,9 +2473,10 @@ case class Flatten(child: Expression) extends UnaryExpression { | |
| |for (int z = 0; z < $childVariableName.numElements(); z++) { | ||
| | $variableName += $childVariableName.getArray(z).numElements(); | ||
| |} | ||
| |if ($variableName > $MAX_ARRAY_LENGTH) { | ||
| |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
| | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + | ||
| | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); | ||
| | $variableName + " elements due to exceeding the array size limit" + | ||
| | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | ||
| |} | ||
| """.stripMargin | ||
| (code, variableName) | ||
|
|
@@ -2602,7 +2600,7 @@ case class Sequence( | |
|
|
||
| override def nullable: Boolean = children.exists(_.nullable) | ||
|
|
||
| override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) | ||
| override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val startType = start.dataType | ||
|
|
@@ -2633,7 +2631,7 @@ case class Sequence( | |
| stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), | ||
| timeZoneId) | ||
|
|
||
| private lazy val impl: SequenceImpl = dataType.elementType match { | ||
| @transient private lazy val impl: SequenceImpl = dataType.elementType match { | ||
| case iType: IntegralType => | ||
| type T = iType.InternalType | ||
| val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) | ||
|
|
@@ -2953,8 +2951,6 @@ object Sequence { | |
| case class ArrayRepeat(left: Expression, right: Expression) | ||
| extends BinaryExpression with ExpectsInputTypes { | ||
|
|
||
| private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH | ||
|
|
||
| override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) | ||
|
|
@@ -2966,9 +2962,9 @@ case class ArrayRepeat(left: Expression, right: Expression) | |
| if (count == null) { | ||
| null | ||
| } else { | ||
| if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { | ||
| if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { | ||
| throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + | ||
| s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); | ||
| s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | ||
| } | ||
| val element = left.eval(input) | ||
| new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) | ||
|
|
@@ -3027,9 +3023,10 @@ case class ArrayRepeat(left: Expression, right: Expression) | |
| |if ($count > 0) { | ||
| | $numElements = $count; | ||
| |} | ||
| |if ($numElements > $MAX_ARRAY_LENGTH) { | ||
| |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
| | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + | ||
| | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); | ||
| | " elements due to exceeding the array size limit" + | ||
| | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | ||
| |} | ||
| """.stripMargin | ||
|
|
||
|
|
@@ -3111,7 +3108,7 @@ case class ArrayRemove(left: Expression, right: Expression) | |
| Seq(ArrayType, elementType) | ||
| } | ||
|
|
||
| lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType | ||
| private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(right.dataType) | ||
|
|
@@ -3228,7 +3225,7 @@ case class ArrayDistinct(child: Expression) | |
|
|
||
| override def dataType: DataType = child.dataType | ||
|
|
||
| @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
| @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| @transient private lazy val ordering: Ordering[Any] = | ||
| TypeUtils.getInterpretedOrdering(elementType) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think we don't need the braces