Skip to content

Commit 21425c8

Browse files
committed
[SPARK-9161][SQL] codegen FormatNumber
1 parent c6fe9b4 commit 21425c8

File tree

1 file changed

+54
-14
lines changed

1 file changed

+54
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -849,22 +849,15 @@ case class FormatNumber(x: Expression, d: Expression)
849849
@transient
850850
private val numberFormat: DecimalFormat = new DecimalFormat("")
851851

852-
override def eval(input: InternalRow): Any = {
853-
val xObject = x.eval(input)
854-
if (xObject == null) {
855-
return null
856-
}
857-
858-
val dObject = d.eval(input)
859-
860-
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
852+
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
853+
val dValue = dObject.asInstanceOf[Int]
854+
if (dValue < 0) {
861855
return null
862856
}
863-
val dValue = dObject.asInstanceOf[Int]
864857

865858
if (dValue != lastDValue) {
866859
// construct a new DecimalFormat only if a new dValue
867-
pattern.delete(0, pattern.length())
860+
pattern.delete(0, pattern.length)
868861
pattern.append("#,###,###,###,###,###,##0")
869862

870863
// decimal place
@@ -877,9 +870,10 @@ case class FormatNumber(x: Expression, d: Expression)
877870
pattern.append("0")
878871
}
879872
}
880-
val dFormat = new DecimalFormat(pattern.toString())
881-
lastDValue = dValue;
882-
numberFormat.applyPattern(dFormat.toPattern())
873+
val dFormat = new DecimalFormat(pattern.toString)
874+
lastDValue = dValue
875+
876+
numberFormat.applyPattern(dFormat.toPattern)
883877
}
884878

885879
x.dataType match {
@@ -894,6 +888,52 @@ case class FormatNumber(x: Expression, d: Expression)
894888
}
895889
}
896890

891+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
892+
nullSafeCodeGen(ctx, ev, (num, d) => {
893+
894+
def typeHelper(p: String): String = {
895+
x.dataType match {
896+
case _ : DecimalType => s"""$p.toJavaBigDecimal()"""
897+
case _ => s"$p"
898+
}
899+
}
900+
901+
val sb = classOf[StringBuffer].getName
902+
val df = classOf[DecimalFormat].getName
903+
val lastDValue = ctx.freshName("lastDValue")
904+
val pattern = ctx.freshName("pattern")
905+
val numberFormat = ctx.freshName("numberFormat")
906+
val i = ctx.freshName("i")
907+
val dFormat = ctx.freshName("dFormat")
908+
ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;")
909+
ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
910+
ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""")
911+
912+
s"""
913+
if ($d >= 0) {
914+
$pattern.delete(0, $pattern.length());
915+
if ($d != $lastDValue) {
916+
$pattern.append("#,###,###,###,###,###,##0");
917+
918+
if ($d > 0) {
919+
$pattern.append(".");
920+
for (int $i = 0; $i < $d; $i++) {
921+
$pattern.append("0");
922+
}
923+
}
924+
$df $dFormat = new $df($pattern.toString());
925+
$lastDValue = $d;
926+
$numberFormat.applyPattern($dFormat.toPattern());
927+
${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
928+
}
929+
} else {
930+
${ev.primitive} = null;
931+
${ev.isNull} = true;
932+
}
933+
"""
934+
})
935+
}
936+
897937
override def prettyName: String = "format_number"
898938
}
899939

0 commit comments

Comments
 (0)