Skip to content

Commit d04e043

Browse files
committed
Type alias bug fix & changes based on code review
Dealias collection type before obtaining its companion object Change collClass to Option Rename variables
1 parent b5f87bd commit d04e043

File tree

2 files changed

+32
-29
lines changed

2 files changed

+32
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,11 @@ object ScalaReflection extends ScalaReflection {
307307
}
308308
}
309309

310-
val cls = t.companion.decl(TermName("newBuilder")) match {
310+
val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
311311
case NoSymbol => classOf[Seq[_]]
312312
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
313313
}
314-
MapObjects(mapFunction, getPath, dataType, cls)
314+
MapObjects(mapFunction, getPath, dataType, Some(cls))
315315

316316
case t if t <:< localTypeOf[Map[_, _]] =>
317317
// TODO: add walked type path for map

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -430,20 +430,21 @@ object MapObjects {
430430
* @param function The function applied on the collection elements.
431431
* @param inputData An expression that when evaluated returns a collection object.
432432
* @param elementType The data type of elements in the collection.
433-
* @param collClass The class of the resulting collection
433+
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
434+
* or None (returning ArrayType)
434435
*/
435436
def apply(
436437
function: Expression => Expression,
437438
inputData: Expression,
438439
elementType: DataType,
439-
collClass: Class[_] = classOf[Array[_]]): MapObjects = {
440+
customCollectionCls: Option[Class[_]] = None): MapObjects = {
440441
val id = curId.getAndIncrement()
441442
val loopValue = s"MapObjects_loopValue$id"
442443
val loopIsNull = s"MapObjects_loopIsNull$id"
443444
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
444445
val builderValue = s"MapObjects_builderValue$id"
445446
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
446-
collClass, builderValue)
447+
customCollectionCls, builderValue)
447448
}
448449
}
449450

@@ -453,8 +454,8 @@ object MapObjects {
453454
* function is expressed using catalyst expressions.
454455
*
455456
* The type of the result is determined as follows:
456-
* - ArrayType - when collClass is an array class
457-
* - ObjectType(collClass) - when collClass is a collection class
457+
* - ArrayType - when customCollectionCls is None
458+
* - ObjectType(collection) - when customCollectionCls contains a collection class
458459
*
459460
* The following collection ObjectTypes are currently supported on input:
460461
* Seq, Array, ArrayData, java.util.List
@@ -468,7 +469,8 @@ object MapObjects {
468469
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
469470
* to handle collection elements.
470471
* @param inputData An expression that when evaluated returns a collection object.
471-
* @param collClass The class of the resulting collection
472+
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
473+
* or None (returning ArrayType)
472474
* @param builderValue The name of the builder variable used to construct the resulting collection
473475
* (used only when returning ObjectType)
474476
*/
@@ -478,7 +480,7 @@ case class MapObjects private(
478480
loopVarDataType: DataType,
479481
lambdaFunction: Expression,
480482
inputData: Expression,
481-
collClass: Class[_],
483+
customCollectionCls: Option[Class[_]],
482484
builderValue: String) extends Expression with NonSQLExpression {
483485

484486
override def nullable: Boolean = inputData.nullable
@@ -489,8 +491,8 @@ case class MapObjects private(
489491
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
490492

491493
override def dataType: DataType =
492-
if (!collClass.isArray) ObjectType(collClass)
493-
else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
494+
customCollectionCls.map(ObjectType.apply).getOrElse(
495+
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
494496

495497
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
496498
val elementJavaType = ctx.javaType(loopVarDataType)
@@ -573,22 +575,23 @@ case class MapObjects private(
573575
case _ => s"$loopIsNull = $loopValue == null;"
574576
}
575577

576-
val (genInit, genAssign, genResult): (String, String => String, String) =
577-
if (collClass.isArray) {
578-
// array
579-
(s"""$convertedType[] $convertedArray = null;
580-
$convertedArray = $arrayConstructor;""",
581-
genValue => s"$convertedArray[$loopIndex] = $genValue;",
582-
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
583-
} else {
584-
// collection
585-
val collObjectName = s"${collClass.getName}$$.MODULE$$"
586-
val getBuilderVar = s"$collObjectName.newBuilder()"
578+
val (initCollection, addElement, getResult): (String, String => String, String) =
579+
customCollectionCls match {
580+
case Some(cls) =>
581+
// collection
582+
val collObjectName = s"${cls.getName}$$.MODULE$$"
583+
val getBuilderVar = s"$collObjectName.newBuilder()"
587584

588-
(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
585+
(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
589586
$builderValue.sizeHint($dataLength);""",
590-
genValue => s"$builderValue.$$plus$$eq($genValue);",
591-
s"(${collClass.getName}) $builderValue.result();")
587+
genValue => s"$builderValue.$$plus$$eq($genValue);",
588+
s"(${cls.getName}) $builderValue.result();")
589+
case None =>
590+
// array
591+
(s"""$convertedType[] $convertedArray = null;
592+
$convertedArray = $arrayConstructor;""",
593+
genValue => s"$convertedArray[$loopIndex] = $genValue;",
594+
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
592595
}
593596

594597
val code = s"""
@@ -598,7 +601,7 @@ case class MapObjects private(
598601
if (!${genInputData.isNull}) {
599602
$determineCollectionType
600603
int $dataLength = $getLength;
601-
$genInit
604+
$initCollection
602605

603606
int $loopIndex = 0;
604607
while ($loopIndex < $dataLength) {
@@ -607,15 +610,15 @@ case class MapObjects private(
607610

608611
${genFunction.code}
609612
if (${genFunction.isNull}) {
610-
${genAssign("null")}
613+
${addElement("null")}
611614
} else {
612-
${genAssign(genFunctionValue)}
615+
${addElement(genFunctionValue)}
613616
}
614617

615618
$loopIndex += 1;
616619
}
617620

618-
${ev.value} = $genResult
621+
${ev.value} = $getResult
619622
}
620623
"""
621624
ev.copy(code = code, isNull = genInputData.isNull)

0 commit comments

Comments
 (0)