@@ -849,22 +849,15 @@ case class FormatNumber(x: Expression, d: Expression)
849849 @ transient
850850 private val numberFormat : DecimalFormat = new DecimalFormat (" " )
851851
852- override def eval (input : InternalRow ): Any = {
853- val xObject = x.eval(input)
854- if (xObject == null ) {
855- return null
856- }
857-
858- val dObject = d.eval(input)
859-
860- if (dObject == null || dObject.asInstanceOf [Int ] < 0 ) {
852+ override protected def nullSafeEval (xObject : Any , dObject : Any ): Any = {
853+ val dValue = dObject.asInstanceOf [Int ]
854+ if (dValue < 0 ) {
861855 return null
862856 }
863- val dValue = dObject.asInstanceOf [Int ]
864857
865858 if (dValue != lastDValue) {
866859 // construct a new DecimalFormat only if a new dValue
867- pattern.delete(0 , pattern.length() )
860+ pattern.delete(0 , pattern.length)
868861 pattern.append(" #,###,###,###,###,###,##0" )
869862
870863 // decimal place
@@ -877,9 +870,10 @@ case class FormatNumber(x: Expression, d: Expression)
877870 pattern.append(" 0" )
878871 }
879872 }
880- val dFormat = new DecimalFormat (pattern.toString())
881- lastDValue = dValue;
882- numberFormat.applyPattern(dFormat.toPattern())
873+ val dFormat = new DecimalFormat (pattern.toString)
874+ lastDValue = dValue
875+
876+ numberFormat.applyPattern(dFormat.toPattern)
883877 }
884878
885879 x.dataType match {
@@ -894,6 +888,52 @@ case class FormatNumber(x: Expression, d: Expression)
894888 }
895889 }
896890
891+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
892+ nullSafeCodeGen(ctx, ev, (num, d) => {
893+
894+ def typeHelper (p : String ): String = {
895+ x.dataType match {
896+ case _ : DecimalType => s """ $p.toJavaBigDecimal() """
897+ case _ => s " $p"
898+ }
899+ }
900+
901+ val sb = classOf [StringBuffer ].getName
902+ val df = classOf [DecimalFormat ].getName
903+ val lastDValue = ctx.freshName(" lastDValue" )
904+ val pattern = ctx.freshName(" pattern" )
905+ val numberFormat = ctx.freshName(" numberFormat" )
906+ val i = ctx.freshName(" i" )
907+ val dFormat = ctx.freshName(" dFormat" )
908+ ctx.addMutableState(" int" , lastDValue, s " $lastDValue = -100; " )
909+ ctx.addMutableState(sb, pattern, s " $pattern = new $sb(); " )
910+ ctx.addMutableState(df, numberFormat, s """ $numberFormat = new $df(""); """ )
911+
912+ s """
913+ if ( $d >= 0) {
914+ $pattern.delete(0, $pattern.length());
915+ if ( $d != $lastDValue) {
916+ $pattern.append("#,###,###,###,###,###,##0");
917+
918+ if ( $d > 0) {
919+ $pattern.append(".");
920+ for (int $i = 0; $i < $d; $i++) {
921+ $pattern.append("0");
922+ }
923+ }
924+ $df $dFormat = new $df( $pattern.toString());
925+ $lastDValue = $d;
926+ $numberFormat.applyPattern( $dFormat.toPattern());
927+ ${ev.primitive} = UTF8String.fromString( $numberFormat.format( ${typeHelper(num)}));
928+ }
929+ } else {
930+ ${ev.primitive} = null;
931+ ${ev.isNull} = true;
932+ }
933+ """
934+ })
935+ }
936+
897937 override def prettyName : String = " format_number"
898938}
899939
0 commit comments