Add function round (squash)
zhichao-li committed Jun 24, 2015
commit 4e66e8da92324fa07cf1476ffc282876d478e190
@@ -124,6 +124,7 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
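Note: registering the expression under the name "round" is what makes it resolvable from SQL strings and from DataFrame.selectExpr. A minimal usage sketch, assuming an active SQLContext named sqlContext (the literal values are illustrative):

// Once "round" is in the FunctionRegistry it can be called from SQL:
sqlContext.sql("SELECT round(2.345, 2)").show()            // expected: 2.35

// ...and from selectExpr, which resolves through the same registry:
sqlContext.range(1).selectExpr("round(145.23, -1)").show() // expected: 150.0 (Hive-compatible, see the tests below)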
@@ -20,10 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
import java.lang.{Long => JLong}

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

import scala.math.BigDecimal.RoundingMode
import scala.math.BigDecimal.RoundingMode._

/**
* A leaf expression specifically for math constants. Math constants expect no input.
* @param c The math constant.
@@ -308,6 +316,118 @@ case class Logarithm(left: Expression, right: Expression)
logCode + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
}
}

case class Round(left: Expression, right: Expression)
extends Expression with trees.BinaryNode[Expression] with Serializable {

def this(left: Expression) = {
this(left, Literal(0))
}

override def nullable: Boolean = left.nullable || right.nullable

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
[Review comment — Member] Why is fixed Decimal not supported for Round?
[Reply — Author] I suppose it would support fixed Decimal.


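For reference, the HALF_UP semantics used throughout this patch can be shown on a plain scala.math.BigDecimal; this is only a sketch of the intended rounding behavior, not the Spark Decimal code path:

import scala.math.BigDecimal.RoundingMode

// Illustrative helper (not part of this patch): round a decimal value to the
// requested scale with HALF_UP, which is what a fixed-precision DecimalType
// branch would also have to do.
def roundDecimal(value: BigDecimal, scale: Int): BigDecimal =
  value.setScale(scale, RoundingMode.HALF_UP)

// roundDecimal(BigDecimal("1.26"), 1) == BigDecimal("1.3")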
override def checkInputDataTypes(): TypeCheckResult = {
if ((left.dataType.isInstanceOf[NumericType] || left.dataType.isInstanceOf[NullType])
[Review comment — Member] The scale of Literal(l, LongType) when l > Int.MaxValue || l < Int.MinValue may be considered invalid input in Hive.
[Reply — Author] The scale type has been fixed to be Int, so I guess this check is redundant?

&& (right.dataType.isInstanceOf[IntegerType] || right.dataType.isInstanceOf[NullType])) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"round accepts numeric types as the value and integer type as the scale")
}
}
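A small sketch of what this check accepts and rejects, which also addresses the Long-scale concern raised above (the expression names come from this patch; the assertions are illustrative and not part of the test suite):

// An IntegerType scale passes the check...
val ok = Round(Literal(3.14), Literal(2))
assert(ok.checkInputDataTypes().isSuccess)

// ...while a LongType scale (e.g. a value beyond Int range) is rejected outright,
// so Hive's special handling of overflowing long scales never comes into play here.
val bad = Round(Literal(3.14), Literal(Int.MaxValue.toLong + 1))
assert(!bad.checkInputDataTypes().isSuccess)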

override def toString: String = s"round($left, $right)"

override def dataType: DataType = left.dataType
[Review comment — Member] If we plan to support StringType and BinaryType as legitimate left operands of Round, just as Hive does, this will no longer hold.


override def eval(input: InternalRow): Any = {
val value = left.eval(input)
val scale = right.eval(input)
if (value == null || scale == null) {
null
} else {
dataType match {
case _: DecimalType => {
val result = value.asInstanceOf[Decimal]
result.set(result.toBigDecimal, result.precision, scale.asInstanceOf[Integer])
result
}
case FloatType => {
[Review comment — Member] GenericUDFRound would return null if Float.isNaN(f) or Float.isInfinite(f).
[Reply — Author] Good catch, will add that, same as Hive.

BigDecimal.valueOf(value.asInstanceOf[Float].toDouble)
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).floatValue()
}
case DoubleType => {
BigDecimal.valueOf(value.asInstanceOf[Double])
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).doubleValue()
}
case LongType => {
BigDecimal.valueOf(value.asInstanceOf[Long])
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).longValue()
}
case IntegerType => {
BigDecimal.valueOf(value.asInstanceOf[Integer].toInt)
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).intValue()
}
case ShortType => {
BigDecimal.valueOf(value.asInstanceOf[Short])
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).shortValue()
}
case ByteType => {
BigDecimal.valueOf(value.asInstanceOf[Byte])
.setScale(scale.asInstanceOf[Integer], RoundingMode.HALF_UP).byteValue()
}
[Review comment — Member] I prefer to support StringType and BinaryType as well.

}
}
}
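A hedged sketch of the NaN/Infinity guard requested in the review above; it is not part of this commit and only mirrors the behavior of Hive's GenericUDFRound, which returns null for non-finite floats:

import scala.math.BigDecimal.RoundingMode

// Illustrative only: the FloatType branch with the Hive-compatible null guard added.
def roundFloatLikeHive(f: Float, scale: Int): Any =
  if (java.lang.Float.isNaN(f) || java.lang.Float.isInfinite(f)) {
    null
  } else {
    BigDecimal.valueOf(f.toDouble).setScale(scale, RoundingMode.HALF_UP).floatValue()
  }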

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, (c1, c2) =>
s"$c1.set($c1.toBigDecimal(), $c1.precision(), $c2)")
case FloatType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf((double)$c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).floatValue()")
case DoubleType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf((double)$c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).doubleValue()")
case LongType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf($c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).longValue()")
case IntegerType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf((long)$c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).intValue()")
case ShortType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf((long)$c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).shortValue()")
case ByteType => defineCodeGen(ctx, ev, (c1, c2) =>
s"java.math.BigDecimal.valueOf((long)$c1)" +
s".setScale(Integer.valueOf($c2), java.math.RoundingMode.HALF_UP).byteValue()")
}

protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitive, eval2.primitive)

s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
}
}
"""
}
@@ -17,10 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
import org.apache.spark.sql.types.{DataType, LongType}
import org.apache.spark.sql.types.{Decimal, DoubleType}

class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

@@ -252,4 +252,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))
}

test("round") {
checkEvaluation(Round(Literal(Decimal(1.26)), Literal(1)), Decimal(1.3, 3, 1))
checkEvaluation(Round(Literal(1.23D), Literal(1)), 1.2)
checkEvaluation(Round(Literal(1.25D), Literal(1)), 1.3)
checkEvaluation(Round(Literal(1.5F), Literal(0)), 2.0F)
checkEvaluation(Round(Literal(1.toShort), Literal(0)), 1.toShort)
checkEvaluation(Round(Literal(2.toByte), Literal(0)), 2.toByte)
checkEvaluation(Round(Literal(23456789L), Literal(0)), 23456789L)
checkEvaluation(Round(Literal(123), Literal(0)), 123)
}
}
32 changes: 32 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1282,6 +1282,38 @@
*/
def rint(columnName: String): Column = rint(Column(columnName))

/**
* Returns the value of the given column rounded to 0 decimal places (HALF_UP).
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column): Column = round(e.expr, 0)

/**
* Returns the value of the given column rounded to 0 decimal places (HALF_UP).
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String): Column = round(Column(columnName))

/**
* Returns the value of the given column rounded to `scale` decimal places (HALF_UP).
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column, scale: Int): Column = Round(e.expr, lit(scale).expr)

/**
* Returns the value of the given column rounded to `scale` decimal places (HALF_UP).
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)
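A minimal usage sketch of the four overloads added above, assuming an existing DataFrame df with a numeric column "x" (both names are illustrative):

import org.apache.spark.sql.functions.round

df.select(round(df("x")))      // scale 0
df.select(round("x"))          // scale 0, addressed by column name
df.select(round(df("x"), 2))   // explicit scale
df.select(round("x", 2))       // explicit scale, by column name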

/**
* Computes the signum of the given value.
*
@@ -17,8 +17,10 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType}


private object MathExpressionsTestData {
@@ -292,4 +294,20 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(df.selectExpr("positive(b)"), Row(-1))
checkAnswer(df.selectExpr("positive(c)"), Row("abc"))
}

test("round") {
val df = Seq((1.53, 0.62, 12345L, 0.67.toFloat)).toDF("a", "b", "c", "d")
checkAnswer(df.select(round('a)), Row(2.0))
checkAnswer(df.select(round('b, 1)), Row(0.6))
checkAnswer(df.selectExpr("round(a)"), Row(2))
checkAnswer(df.selectExpr("round(b, 1)"), Row(0.6))
checkAnswer(df.selectExpr("round(c, 1)"), Row(12345L))
checkAnswer(df.selectExpr("round(d, 1)"), Row(0.7f))
checkAnswer(df.selectExpr("round(null)"), Row(null))
checkAnswer(df.selectExpr("round(null, 1)"), Row(null))
checkAnswer(df.selectExpr("round(145.23, -1)"), Row(150.0)) // same as hive
checkAnswer(df.selectExpr("round(20, 1)"), Row(20))
checkAnswer(df.selectExpr("round(1.0/0.0, 1)"), Row(null))
}
}
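A short note on the round(145.23, -1) expectation in the suite above: a negative scale rounds to the left of the decimal point, which plain BigDecimal arithmetic also shows (a sketch, not part of the patch):

import scala.math.BigDecimal.RoundingMode

// HALF_UP with scale -1 rounds to the nearest ten, hence 145.23 -> 150.
BigDecimal("145.23").setScale(-1, RoundingMode.HALF_UP)   // == 150 (printed as 1.5E+2)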
@@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_repeat",
"udf_rlike",
"udf_round",
// "udf_round_3", TODO: FIX THIS failed due to cast exception
"udf_round_3",
"udf_rpad",
"udf_rtrim",
"udf_second",