Skip to content

Commit 1ddd0f2

Browse files
tarekbeckerrxin
authored andcommitted
[SPARK-9161][SQL] codegen FormatNumber
Jira https://issues.apache.org/jira/browse/SPARK-9161 Author: Tarek Auel <tarek.auel@googlemail.com> Closes apache#7545 from tarekauel/SPARK-9161 and squashes the following commits: 21425c8 [Tarek Auel] [SPARK-9161][SQL] codegen FormatNumber
1 parent 228ab65 commit 1ddd0f2

File tree

1 file changed

+54
-14
lines changed

1 file changed

+54
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -902,22 +902,15 @@ case class FormatNumber(x: Expression, d: Expression)
902902
@transient
903903
private val numberFormat: DecimalFormat = new DecimalFormat("")
904904

905-
override def eval(input: InternalRow): Any = {
906-
val xObject = x.eval(input)
907-
if (xObject == null) {
905+
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
906+
val dValue = dObject.asInstanceOf[Int]
907+
if (dValue < 0) {
908908
return null
909909
}
910910

911-
val dObject = d.eval(input)
912-
913-
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
914-
return null
915-
}
916-
val dValue = dObject.asInstanceOf[Int]
917-
918911
if (dValue != lastDValue) {
919912
// construct a new DecimalFormat only if a new dValue
920-
pattern.delete(0, pattern.length())
913+
pattern.delete(0, pattern.length)
921914
pattern.append("#,###,###,###,###,###,##0")
922915

923916
// decimal place
@@ -930,9 +923,10 @@ case class FormatNumber(x: Expression, d: Expression)
930923
pattern.append("0")
931924
}
932925
}
933-
val dFormat = new DecimalFormat(pattern.toString())
934-
lastDValue = dValue;
935-
numberFormat.applyPattern(dFormat.toPattern())
926+
val dFormat = new DecimalFormat(pattern.toString)
927+
lastDValue = dValue
928+
929+
numberFormat.applyPattern(dFormat.toPattern)
936930
}
937931

938932
x.dataType match {
@@ -947,6 +941,52 @@ case class FormatNumber(x: Expression, d: Expression)
947941
}
948942
}
949943

944+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
945+
nullSafeCodeGen(ctx, ev, (num, d) => {
946+
947+
def typeHelper(p: String): String = {
948+
x.dataType match {
949+
case _ : DecimalType => s"""$p.toJavaBigDecimal()"""
950+
case _ => s"$p"
951+
}
952+
}
953+
954+
val sb = classOf[StringBuffer].getName
955+
val df = classOf[DecimalFormat].getName
956+
val lastDValue = ctx.freshName("lastDValue")
957+
val pattern = ctx.freshName("pattern")
958+
val numberFormat = ctx.freshName("numberFormat")
959+
val i = ctx.freshName("i")
960+
val dFormat = ctx.freshName("dFormat")
961+
ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;")
962+
ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
963+
ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""")
964+
965+
s"""
966+
if ($d >= 0) {
967+
$pattern.delete(0, $pattern.length());
968+
if ($d != $lastDValue) {
969+
$pattern.append("#,###,###,###,###,###,##0");
970+
971+
if ($d > 0) {
972+
$pattern.append(".");
973+
for (int $i = 0; $i < $d; $i++) {
974+
$pattern.append("0");
975+
}
976+
}
977+
$df $dFormat = new $df($pattern.toString());
978+
$lastDValue = $d;
979+
$numberFormat.applyPattern($dFormat.toPattern());
980+
${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
981+
}
982+
} else {
983+
${ev.primitive} = null;
984+
${ev.isNull} = true;
985+
}
986+
"""
987+
})
988+
}
989+
950990
override def prettyName: String = "format_number"
951991
}
952992

0 commit comments

Comments
 (0)