@@ -902,22 +902,15 @@ case class FormatNumber(x: Expression, d: Expression)
902902 @ transient
903903 private val numberFormat : DecimalFormat = new DecimalFormat (" " )
904904
905- override def eval ( input : InternalRow ): Any = {
906- val xObject = x.eval(input)
907- if (xObject == null ) {
905+ override protected def nullSafeEval ( xObject : Any , dObject : Any ): Any = {
906+ val dValue = dObject. asInstanceOf [ Int ]
907+ if (dValue < 0 ) {
908908 return null
909909 }
910910
911- val dObject = d.eval(input)
912-
913- if (dObject == null || dObject.asInstanceOf [Int ] < 0 ) {
914- return null
915- }
916- val dValue = dObject.asInstanceOf [Int ]
917-
918911 if (dValue != lastDValue) {
919912 // construct a new DecimalFormat only if a new dValue
920- pattern.delete(0 , pattern.length() )
913+ pattern.delete(0 , pattern.length)
921914 pattern.append(" #,###,###,###,###,###,##0" )
922915
923916 // decimal place
@@ -930,9 +923,10 @@ case class FormatNumber(x: Expression, d: Expression)
930923 pattern.append(" 0" )
931924 }
932925 }
933- val dFormat = new DecimalFormat (pattern.toString())
934- lastDValue = dValue;
935- numberFormat.applyPattern(dFormat.toPattern())
926+ val dFormat = new DecimalFormat (pattern.toString)
927+ lastDValue = dValue
928+
929+ numberFormat.applyPattern(dFormat.toPattern)
936930 }
937931
938932 x.dataType match {
@@ -947,6 +941,52 @@ case class FormatNumber(x: Expression, d: Expression)
947941 }
948942 }
949943
944+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
945+ nullSafeCodeGen(ctx, ev, (num, d) => {
946+
947+ def typeHelper (p : String ): String = {
948+ x.dataType match {
949+ case _ : DecimalType => s """ $p.toJavaBigDecimal() """
950+ case _ => s " $p"
951+ }
952+ }
953+
954+ val sb = classOf [StringBuffer ].getName
955+ val df = classOf [DecimalFormat ].getName
956+ val lastDValue = ctx.freshName(" lastDValue" )
957+ val pattern = ctx.freshName(" pattern" )
958+ val numberFormat = ctx.freshName(" numberFormat" )
959+ val i = ctx.freshName(" i" )
960+ val dFormat = ctx.freshName(" dFormat" )
961+ ctx.addMutableState(" int" , lastDValue, s " $lastDValue = -100; " )
962+ ctx.addMutableState(sb, pattern, s " $pattern = new $sb(); " )
963+ ctx.addMutableState(df, numberFormat, s """ $numberFormat = new $df(""); """ )
964+
965+ s """
966+ if ( $d >= 0) {
967+ $pattern.delete(0, $pattern.length());
968+ if ( $d != $lastDValue) {
969+ $pattern.append("#,###,###,###,###,###,##0");
970+
971+ if ( $d > 0) {
972+ $pattern.append(".");
973+ for (int $i = 0; $i < $d; $i++) {
974+ $pattern.append("0");
975+ }
976+ }
977+ $df $dFormat = new $df( $pattern.toString());
978+ $lastDValue = $d;
979+ $numberFormat.applyPattern( $dFormat.toPattern());
980+ ${ev.primitive} = UTF8String.fromString( $numberFormat.format( ${typeHelper(num)}));
981+ }
982+ } else {
983+ ${ev.primitive} = null;
984+ ${ev.isNull} = true;
985+ }
986+ """
987+ })
988+ }
989+
950990 override def prettyName : String = " format_number"
951991}
952992
0 commit comments