Skip to content
Closed
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,23 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)

override def dataType: DataType = ArrayType(mountSchema)

override def nullable: Boolean = children.exists(_.nullable)

private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])

private lazy val arrayElementTypes = arrayTypes.map(_.elementType)

@transient private lazy val mountSchema: StructType = {
@transient override lazy val dataType: DataType = {
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(idx.toString, elementType, nullable = true)
}
StructType(fields)
ArrayType(StructType(fields), containsNull = false)
}

@transient lazy val numberOfArrays: Int = children.length
override def nullable: Boolean = children.exists(_.nullable)

@transient private lazy val arrayElementTypes = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we don't need the braces

children.map(_.dataType.asInstanceOf[ArrayType].elementType)
}

@transient lazy val genericArrayData = classOf[GenericArrayData].getName
private def genericArrayData = classOf[GenericArrayData].getName

def emptyInputGenCode(ev: ExprCode): ExprCode = {
ev.copy(code"""
Expand Down Expand Up @@ -256,7 +252,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
("ArrayData[]", arrVals) :: Nil)

val initVariables = s"""
|ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
|ArrayData[] $arrVals = new ArrayData[${children.length}];
|int $biggestCardinality = 0;
|${CodeGenerator.javaType(dataType)} ${ev.value} = null;
""".stripMargin
Expand All @@ -268,7 +264,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
|if (!${ev.isNull}) {
| Object[] $args = new Object[$biggestCardinality];
| for (int $i = 0; $i < $biggestCardinality; $i ++) {
| Object[] $currentRow = new Object[$numberOfArrays];
| Object[] $currentRow = new Object[${children.length}];
| $getValueForTypeSplitted
| $args[$i] = new $genericInternalRow($currentRow);
| }
Expand All @@ -278,7 +274,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (numberOfArrays == 0) {
if (children.length == 0) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
Expand Down Expand Up @@ -360,7 +356,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
private def childDataType: MapType = child.dataType.asInstanceOf[MapType]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a lazy val

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that one. Thanks!


override def dataType: DataType = {
ArrayType(
Expand Down Expand Up @@ -520,7 +516,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
}
}

override def dataType: MapType = {
@transient override lazy val dataType: MapType = {
if (children.isEmpty) {
MapType(StringType, StringType)
} else {
Expand Down Expand Up @@ -737,21 +733,22 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {

@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
case ArrayType(
StructType(Array(
StructField(_, keyType, keyNullable, _),
StructField(_, valueType, valueNullable, _))),
containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull))
case _ => None
@transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an unneeded change, isn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I wanted to be consistent in terms of formatting. (@transient to be on the same line as private lazy val dataTypeDetails) After the change, two lines were exceeding 100 characters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but this seems an unneeded change to me and I think there are other places where we use this syntax, so I see no reason to change it

child.dataType match {
case ArrayType(
StructType(Array(
StructField(_, kt, kn, _),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is there any reason to change variable names? It would be good to minimize differences for review and ease of understanding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the motivation is described here. I will revert this piece of code shortly.

StructField(_, vt, vn, _))),
cn) => Some((MapType(kt, vt, vn), kn, cn))
case _ => None
}
}

private def nullEntries: Boolean = dataTypeDetails.get._3
@transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3

override def nullable: Boolean = child.nullable || nullEntries

override def dataType: MapType = dataTypeDetails.get._1
@transient override lazy val dataType: MapType = dataTypeDetails.get._1

override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
case Some(_) => TypeCheckResult.TypeCheckSuccess
Expand Down Expand Up @@ -949,8 +946,7 @@ trait ArraySortLike extends ExpectsInputTypes {

protected def nullOrder: NullOrder

@transient
private lazy val lt: Comparator[Any] = {
@transient private lazy val lt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -972,8 +968,7 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}

@transient
private lazy val gt: Comparator[Any] = {
@transient private lazy val gt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -995,7 +990,10 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}

def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType
@transient lazy val elementType: DataType = {
arrayExpression.dataType.asInstanceOf[ArrayType].elementType
}

def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull

def sortEval(array: Any, ascending: Boolean): Any = {
Expand Down Expand Up @@ -1211,7 +1209,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI

override def dataType: DataType = child.dataType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
Expand Down Expand Up @@ -1603,7 +1601,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)

override def children: Seq[Expression] = Seq(x, start, length)

lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
val startInt = startVal.asInstanceOf[Int]
Expand Down Expand Up @@ -1889,7 +1887,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
@transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
Expand Down Expand Up @@ -1930,7 +1928,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
min
}

override def dataType: DataType = child.dataType match {
@transient override lazy val dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}
Expand All @@ -1954,7 +1952,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
@transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
Expand Down Expand Up @@ -1995,7 +1993,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
max
}

override def dataType: DataType = child.dataType match {
@transient override lazy val dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}
Expand Down Expand Up @@ -2100,7 +2098,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)

override def dataType: DataType = left.dataType match {
@transient override lazy val dataType: DataType = left.dataType match {
case ArrayType(elementType, _) => elementType
case MapType(_, valueType, _) => valueType
}
Expand Down Expand Up @@ -2209,9 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
""")
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

val allowedTypes = Seq(StringType, BinaryType, ArrayType)
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
Expand All @@ -2228,15 +2224,15 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
}
}

override def dataType: DataType = {
@transient override lazy val dataType: DataType = {
if (children.isEmpty) {
StringType
} else {
super.dataType
}
}

lazy val javaType: String = CodeGenerator.javaType(dataType)
private def javaType: String = CodeGenerator.javaType(dataType)

override def nullable: Boolean = children.exists(_.nullable)

Expand All @@ -2256,9 +2252,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numberOfElements > MAX_ARRAY_LENGTH) {
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
" elements due to exceeding the array size limit " +
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
Expand Down Expand Up @@ -2316,9 +2313,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|for (int z = 0; z < ${children.length}; z++) {
| $numElements += args[z].numElements();
|}
|if ($numElements > $MAX_ARRAY_LENGTH) {
|if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
| " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
| " elements due to exceeding the array size limit" +
| " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin

Expand Down Expand Up @@ -2413,15 +2411,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

override def nullable: Boolean = child.nullable || childDataType.containsNull

override def dataType: DataType = childDataType.elementType
@transient override lazy val dataType: DataType = childDataType.elementType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(_: ArrayType, _) =>
Expand All @@ -2441,9 +2437,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
} else {
val arrayData = elements.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
if (numberOfElements > MAX_ARRAY_LENGTH) {
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
s"$numberOfElements elements due to exceeding the array size limit " +
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val flattenedData = new Array(numberOfElements.toInt)
var position = 0
Expand Down Expand Up @@ -2476,9 +2473,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
|for (int z = 0; z < $childVariableName.numElements(); z++) {
| $variableName += $childVariableName.getArray(z).numElements();
|}
|if ($variableName > $MAX_ARRAY_LENGTH) {
|if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
| $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
| $variableName + " elements due to exceeding the array size limit" +
| " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin
(code, variableName)
Expand Down Expand Up @@ -2602,7 +2600,7 @@ case class Sequence(

override def nullable: Boolean = children.exists(_.nullable)

override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)

override def checkInputDataTypes(): TypeCheckResult = {
val startType = start.dataType
Expand Down Expand Up @@ -2633,7 +2631,7 @@ case class Sequence(
stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
timeZoneId)

private lazy val impl: SequenceImpl = dataType.elementType match {
@transient private lazy val impl: SequenceImpl = dataType.elementType match {
case iType: IntegralType =>
type T = iType.InternalType
val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
Expand Down Expand Up @@ -2953,8 +2951,6 @@ object Sequence {
case class ArrayRepeat(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {

private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)

override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
Expand All @@ -2966,9 +2962,9 @@ case class ArrayRepeat(left: Expression, right: Expression)
if (count == null) {
null
} else {
if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
}
val element = left.eval(input)
new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
Expand Down Expand Up @@ -3027,9 +3023,10 @@ case class ArrayRepeat(left: Expression, right: Expression)
|if ($count > 0) {
| $numElements = $count;
|}
|if ($numElements > $MAX_ARRAY_LENGTH) {
|if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
| " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
| " elements due to exceeding the array size limit" +
| " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin

Expand Down Expand Up @@ -3111,7 +3108,7 @@ case class ArrayRemove(left: Expression, right: Expression)
Seq(ArrayType, elementType)
}

lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
Expand Down Expand Up @@ -3228,7 +3225,7 @@ case class ArrayDistinct(child: Expression)

override def dataType: DataType = child.dataType

@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)
Expand Down