@@ -430,20 +430,21 @@ object MapObjects {
430430 * @param function The function applied on the collection elements.
431431 * @param inputData An expression that when evaluated returns a collection object.
432432 * @param elementType The data type of elements in the collection.
433- * @param collClass The class of the resulting collection
433+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
434+ * or None (returning ArrayType)
434435 */
435436 def apply (
436437 function : Expression => Expression ,
437438 inputData : Expression ,
438439 elementType : DataType ,
439- collClass : Class [_] = classOf [ Array [_]] ): MapObjects = {
440+ customCollectionCls : Option [ Class [_]] = None ): MapObjects = {
440441 val id = curId.getAndIncrement()
441442 val loopValue = s " MapObjects_loopValue $id"
442443 val loopIsNull = s " MapObjects_loopIsNull $id"
443444 val loopVar = LambdaVariable (loopValue, loopIsNull, elementType)
444445 val builderValue = s " MapObjects_builderValue $id"
445446 MapObjects (loopValue, loopIsNull, elementType, function(loopVar), inputData,
446- collClass , builderValue)
447+ customCollectionCls , builderValue)
447448 }
448449}
449450
@@ -453,8 +454,8 @@ object MapObjects {
453454 * function is expressed using catalyst expressions.
454455 *
455456 * The type of the result is determined as follows:
456- * - ArrayType - when collClass is an array class
457- * - ObjectType(collClass ) - when collClass is a collection class
457+ * - ArrayType - when customCollectionCls is None
458+ * - ObjectType(collection ) - when customCollectionCls contains a collection class
458459 *
459460 * The following collection ObjectTypes are currently supported on input:
460461 * Seq, Array, ArrayData, java.util.List
@@ -468,7 +469,8 @@ object MapObjects {
468469 * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
469470 * to handle collection elements.
470471 * @param inputData An expression that when evaluated returns a collection object.
471- * @param collClass The class of the resulting collection
472+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
473+ * or None (returning ArrayType)
472474 * @param builderValue The name of the builder variable used to construct the resulting collection
473475 * (used only when returning ObjectType)
474476 */
@@ -478,7 +480,7 @@ case class MapObjects private(
478480 loopVarDataType : DataType ,
479481 lambdaFunction : Expression ,
480482 inputData : Expression ,
481- collClass : Class [_],
483+ customCollectionCls : Option [ Class [_] ],
482484 builderValue : String ) extends Expression with NonSQLExpression {
483485
484486 override def nullable : Boolean = inputData.nullable
@@ -489,8 +491,8 @@ case class MapObjects private(
489491 throw new UnsupportedOperationException (" Only code-generated evaluation is supported" )
490492
491493 override def dataType : DataType =
492- if ( ! collClass.isArray) ObjectType (collClass)
493- else ArrayType (lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
494+ customCollectionCls.map( ObjectType .apply).getOrElse(
495+ ArrayType (lambdaFunction.dataType, containsNull = lambdaFunction.nullable) )
494496
495497 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
496498 val elementJavaType = ctx.javaType(loopVarDataType)
@@ -573,22 +575,23 @@ case class MapObjects private(
573575 case _ => s " $loopIsNull = $loopValue == null; "
574576 }
575577
576- val (genInit, genAssign, genResult): (String , String => String , String ) =
577- if (collClass.isArray) {
578- // array
579- (s """ $convertedType[] $convertedArray = null;
580- $convertedArray = $arrayConstructor; """ ,
581- genValue => s " $convertedArray[ $loopIndex] = $genValue; " ,
582- s " new ${classOf [GenericArrayData ].getName}( $convertedArray); " )
583- } else {
584- // collection
585- val collObjectName = s " ${collClass.getName}$$ .MODULE $$ "
586- val getBuilderVar = s " $collObjectName.newBuilder() "
578+ val (initCollection, addElement, getResult): (String , String => String , String ) =
579+ customCollectionCls match {
580+ case Some (cls) =>
581+ // collection
582+ val collObjectName = s " ${cls.getName}$$ .MODULE $$ "
583+ val getBuilderVar = s " $collObjectName.newBuilder() "
587584
588- (s """ ${classOf [Builder [_, _]].getName} $builderValue = $getBuilderVar;
585+ (s """ ${classOf [Builder [_, _]].getName} $builderValue = $getBuilderVar;
589586 $builderValue.sizeHint( $dataLength); """ ,
590- genValue => s " $builderValue. $$ plus $$ eq( $genValue); " ,
591- s " ( ${collClass.getName}) $builderValue.result(); " )
587+ genValue => s " $builderValue. $$ plus $$ eq( $genValue); " ,
588+ s " ( ${cls.getName}) $builderValue.result(); " )
589+ case None =>
590+ // array
591+ (s """ $convertedType[] $convertedArray = null;
592+ $convertedArray = $arrayConstructor; """ ,
593+ genValue => s " $convertedArray[ $loopIndex] = $genValue; " ,
594+ s " new ${classOf [GenericArrayData ].getName}( $convertedArray); " )
592595 }
593596
594597 val code = s """
@@ -598,7 +601,7 @@ case class MapObjects private(
598601 if (! ${genInputData.isNull}) {
599602 $determineCollectionType
600603 int $dataLength = $getLength;
601- $genInit
604+ $initCollection
602605
603606 int $loopIndex = 0;
604607 while ( $loopIndex < $dataLength) {
@@ -607,15 +610,15 @@ case class MapObjects private(
607610
608611 ${genFunction.code}
609612 if ( ${genFunction.isNull}) {
610- ${genAssign (" null" )}
613+ ${addElement (" null" )}
611614 } else {
612- ${genAssign (genFunctionValue)}
615+ ${addElement (genFunctionValue)}
613616 }
614617
615618 $loopIndex += 1;
616619 }
617620
618- ${ev.value} = $genResult
621+ ${ev.value} = $getResult
619622 }
620623 """
621624 ev.copy(code = code, isNull = genInputData.isNull)
0 commit comments