[SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL
mgaido91 committed Dec 19, 2017
commit 3037d4aa6afc4d7630d86d29b8dd7d7d724cc990
@@ -93,41 +93,46 @@ object DecimalPrecision extends TypeCoercionRule {
case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e

case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
val resultScale = max(s1, s2)
val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)

case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
val resultScale = max(s1, s2)
val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)

case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
val resultType = DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)

case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
val diff = (intDig + decDig) - DecimalType.MAX_SCALE
if (diff > 0) {
decDig -= diff / 2 + 1
intDig = DecimalType.MAX_SCALE - decDig
}
val resultType = DecimalType.bounded(intDig + decDig, decDig)
// From https://msdn.microsoft.com/en-us/library/ms190476.aspx
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
val intDig = p1 - s1 + s2
val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
val prec = intDig + scale
val resultType = DecimalType.adjustPrecisionScale(prec, scale)
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
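For illustration, a standalone REPL sketch (plain Scala, not the Spark helpers) applying the division rule above together with the same adjustment arithmetic used by adjustPrecisionScale:

// decimal(38,18) / decimal(38,18), following the rule above:
val (p1, s1, p2, s2) = (38, 18, 38, 18)
val intDig = p1 - s1 + s2                                // 38
val scale = math.max(6, s1 + p2 + 1)                     // max(6, 57) = 57
val prec = intDig + scale                                // 95 > 38, so adjust:
val adjScale = math.max(38 - intDig, math.min(scale, 6)) // max(0, 6) = 6
println((38, adjScale))                                  // (38,6)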

case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2),
max(s1, s2))
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)

case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2),
max(s1, s2))
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
@@ -243,17 +248,43 @@ object DecimalPrecision extends TypeCoercionRule {
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
(left.dataType, right.dataType) match {
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
Contributor: nit: I feel it's more readable to just put the new cases for literals before these 4 cases.

Contributor (author): Unfortunately this is not really feasible, since we match on different things: here we match on left.dataType and right.dataType, while for literals we match on left and right.

Contributor: we can do

(left, right) match {
  case (l: Literal, r) => ...

  case (DecimalType.Expression(p, s), r @ IntegralType()) => ...
}

b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
case (DecimalType.Fixed(p, s), t: IntegralType) =>
b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(left, Cast(right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
b.makeCopy(Array(Cast(left, DoubleType), right))
case _ =>
b
}
nondecimalLiteralAndDecimal(b).lift((left, right)).getOrElse(
nondecimalNonliteralAndDecimal(b).applyOrElse((left.dataType, right.dataType),
(_: (DataType, DataType)) => b))
}

/**
* Type coercion for BinaryOperator in which one side is a non-decimal literal numeric, and the
* other side is a decimal.
*/
private def nondecimalLiteralAndDecimal(
Member: Is this rule newly introduced?

mgaido91 (Contributor, author), Dec 21, 2017: Yes, it is. Without it we get failures in the Hive compatibility tests, because Hive uses the exact precision and scale needed by a literal, whereas before this change we used conservative values for each type. For instance, in select 123.12345 * 3, the literal 3 used to be interpreted as Decimal(10, 0), the type for all integers; after the change it becomes Decimal(1, 0), as in Hive. This avoids requiring more precision than is actually needed.
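A minimal REPL sketch of the idea in plain Scala; decimalTypeForLiteral is an illustrative stand-in for the DecimalType.forLiteral/fromBigDecimal helpers added in this patch:

// Type an integral literal by the digits it actually needs, as Hive does.
def decimalTypeForLiteral(v: Long): (Int, Int) = {
  val d = BigDecimal(v)
  (math.max(d.precision, d.scale), d.scale) // (precision, scale)
}

println(decimalTypeForLiteral(3L))         // (1,0) instead of (10,0), the IntegerType default
println(decimalTypeForLiteral(123456789L)) // (9,0)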

b: BinaryOperator): PartialFunction[(Expression, Expression), Expression] = {
// Promote literal integers inside a binary expression with fixed-precision decimals to
// decimals. The precision and scale are the ones needed by the integer value.
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType]
&& l.dataType.isInstanceOf[IntegralType] =>
b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r))
Contributor: What if we don't do this? Requiring more precision seems OK, now that we allow precision loss.

mgaido91 (Contributor, author), Jan 16, 2018: If we don't do this, we get many test failures in spark-hive, because Hive does it this way. Moreover, requiring more precision is not OK, since it leads to a needless loss of precision. Think of this example: you multiply a column of type DECIMAL(38, 18) by 2. Without this change, 2 is considered a DECIMAL(10, 0); according to the rules, the result should be DECIMAL(38 + 10 + 1, 18), which is out of range and therefore becomes DECIMAL(38, 7), potentially losing 11 digits of the fractional part. With this change, 2 is a DECIMAL(1, 0), the result is DECIMAL(38 + 1 + 1, 18), and it becomes DECIMAL(38, 16), with a much smaller precision loss.
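The arithmetic can be checked with a standalone REPL sketch; adjust below is a plain-Scala approximation of the adjustPrecisionScale rule, not the Spark code itself:

// Clamp to MAX_PRECISION = 38, keeping at least min(scale, 6) fractional digits.
def adjust(precision: Int, scale: Int): (Int, Int) =
  if (precision <= 38) (precision, scale)
  else {
    val intDigits = precision - scale
    (38, math.max(38 - intDigits, math.min(scale, 6)))
  }

println(adjust(38 + 10 + 1, 18)) // (38,7): 2 typed as DECIMAL(10,0), 11 fractional digits lost
println(adjust(38 + 1 + 1, 18))  // (38,16): 2 typed as DECIMAL(1,0), only 2 lost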

Contributor: makes sense

Member: Add this example as the code comment?

Member: Hive is also doing this?

Contributor (author): Yes, Hive does this too; that is the reason why I introduced the change (without it, we would have test failures in spark-hive). I will add this to the comment.

case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType]
&& r.dataType.isInstanceOf[IntegralType] =>
b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r))))
}

/**
* Type coercion for BinaryOperator in which one side is a non-decimal non-literal numeric, and
* the other side is a decimal.
*/
private def nondecimalNonliteralAndDecimal(
b: BinaryOperator): PartialFunction[(DataType, DataType), Expression] = {
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
b.makeCopy(Array(Cast(b.left, DecimalType.forType(t)), b.right))
case (DecimalType.Fixed(_, _), t: IntegralType) =>
b.makeCopy(Array(b.left, Cast(b.right, DecimalType.forType(t))))
case (t, DecimalType.Fixed(_, _)) if isFloat(t) =>
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
case (DecimalType.Fixed(_, _), t) if isFloat(t) =>
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
}

}
@@ -58,7 +58,7 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale))
case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
case d: JavaBigDecimal =>
Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
@@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}


/**
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
val MAX_SCALE = 38
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
val MINIMUM_ADJUSTED_SCALE = 6
Member: Before naming a conf, I need to understand the rule you are following: https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql. SQL Server applies MINIMUM_ADJUSTED_SCALE only to multiplication and division; in your implementation, however, you are using it for all the BinaryArithmetic operators?

Contributor (author): Yes, I followed Hive's implementation, which applies this 6-digit minimum to all operations. This means that in those cases SQL Server may round away more digits than we do: we guarantee at least 6 digits of scale, while SQL Server doesn't.
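For instance, a standalone REPL sketch (same adjustment arithmetic as adjustPrecisionScale, not the Spark helpers) showing the 6-digit floor binding for addition too:

// decimal(38,3) + decimal(38,7): result scale is max(3,7) = 7, and the ideal
// precision is max(38-3, 38-7) + 7 + 1 = 43 > 38, so the scale is adjusted:
val (precision, scale) = (43, 7)
val intDigits = precision - scale                           // 36
val adjScale = math.max(38 - intDigits, math.min(scale, 6)) // max(2, 6) = 6
println((38, adjScale)) // (38,6): floored at 6 even for Add, unlike SQL Server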

Contributor (author): @gatorsmile what about spark.sql.decimalOperations.mode, defaulting to native and also accepting hive (and, in the future, sql2011 for throwing an exception instead of returning NULL)?

Contributor: how about spark.sql.decimalOperations.allowTruncate? Let's leave the mode stuff to the type coercion mode.

Contributor: We should make it an internal conf and remove it after some releases.

Contributor (author): ok, I'll go with that, thanks @cloud-fan.


// The decimal types compatible with other numeric types
private[sql] val ByteDecimal = DecimalType(3, 0)
@@ -136,10 +137,54 @@
case DoubleType => DoubleDecimal
}

private[sql] def forLiteral(literal: Literal): DecimalType = literal.value match {
Member: Is this different from forType applied to Literal.dataType?

Contributor (author): Yes, please see my comment above for an example. Thanks.

Contributor: fromLiteral?

Contributor (author): Since we have forType, I used forLiteral to be consistent with the naming.

Contributor: but we also have fromDecimal...

case v: Short => fromBigDecimal(BigDecimal(v))
Member: Can't we just use ShortDecimal, IntDecimal...?

Contributor (author): No, please see my comments above.

case v: Int => fromBigDecimal(BigDecimal(v))
case v: Long => fromBigDecimal(BigDecimal(v))
case _ => forType(literal.dataType)
Member:

  private[sql] def forType(dataType: DataType): DecimalType = dataType match {
    case ByteType => ByteDecimal
    case ShortType => ShortDecimal
    case IntegerType => IntDecimal
    case LongType => LongDecimal
    case FloatType => FloatDecimal
    case DoubleType => DoubleDecimal
  }

This list is incomplete. Is it possible that the input literal is Literal(null, NullType)?

Contributor (author): This problem was present before this PR. Should we fix it here? Is the fix needed? I guess that if it were a problem, it would already have been reported.

}

private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = {
DecimalType(Math.max(d.precision, d.scale), d.scale)
}

private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
}

// scalastyle:off line.size.limit
/**
* Decimal implementation is based on Hive's, which is itself inspired by SQL Server's.
* In particular, when a result precision is greater than {@link #MAX_PRECISION}, the
* corresponding scale is reduced to prevent the integral part of a result from being truncated.
*
* For further reference, please see
* https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/.
Member: Not sure this blog link will stay available for long.

Member: Please remove the web link to the commercial products.

*
* @param precision
* @param scale
* @return
Contributor: remove the above 3 lines

*/
// scalastyle:on line.size.limit
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
Member: The logic in this adjustment function also differs from the MS SQL Server docs:

"In multiplication and division operations we need precision - scale places to store the integral part of the result. The scale might be reduced using the following rules:
- The resulting scale is reduced to min(scale, 38 - (precision - scale)) if the integral part is less than 32, because it cannot be greater than 38 - (precision - scale). The result might be rounded in this case.
- The scale will not be changed if it is less than 6 and the integral part is greater than 32. In this case, an overflow error might be raised if the result cannot fit into decimal(38, scale).
- The scale will be set to 6 if it is greater than 6 and the integral part is greater than 32. In this case, both the integral part and the scale are reduced and the resulting type is decimal(38, 6). The result might be rounded to 6 decimal places, or an overflow error is thrown if the integral part cannot fit into 32 digits."

Contributor (author): Sorry, but I think this is exactly what is described there. The implementation might seem to do different things, but the result is the same: both take the min between 6 and the desired scale when the precision is not enough to represent the whole scale.

Contributor: So the rule in the document is

val resultPrecision = 38
if (intDigits < 32) { // this means scale > 6, as intDigits = precision - scale and precision > 38
  val maxScale = 38 - intDigits
  val resultScale = min(scale, maxScale)
} else {
  if (scale < 6) {
    // can't round, as the scale is already small
    val resultScale = scale
  } else {
    val resultScale = 6
  }
}

I think this is a little different from the current rule

val minScaleValue = Math.min(scale, 6)
val resultScale = max(38 - intDigits, minScaleValue)

Think about the case intDigits < 32: SQL Server takes min(scale, 38 - intDigits), while we take 38 - intDigits.

mgaido91 (Contributor, author): @cloud-fan yes, but keep in mind that we only do this when precision > 38. With some simple math (given intDigits = precision - scale), SQL Server's expression is min(scale, scale + 38 - precision). Since we perform this operation only when precision is greater than 38, the second operand is always the minimum, which means that in such a case SQL Server behaves like us, i.e. it always takes 38 - intDigits. When precision is at most 38, we return the input precision and scale, as SQL Server does. We are just using the precision instead of intDigits in the if.
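A quick standalone check of this equivalence argument, in plain Scala with 38 standing in for MAX_PRECISION:

// When precision > 38: 38 - intDigits = scale + (38 - precision) < scale,
// so min(scale, 38 - intDigits) always picks 38 - intDigits.
for (precision <- 39 to 77; scale <- 0 to precision) {
  val intDigits = precision - scale
  assert(math.min(scale, 38 - intDigits) == 38 - intDigits)
}
println("equivalent for every precision > 38")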

Contributor: ah I see, makes sense

Member: Yeah, this part is consistent.

// Assumptions:
// precision >= scale
// scale >= 0
Member: Use assert to make sure of the assumptions?

Contributor (author): I can add it even though it is not needed: there is no way we can violate those constraints. If you believe it is better to use assert, I will do that.

Contributor: use assert for assumptions, not comments.

if (precision <= MAX_PRECISION) {
// Adjustment only needed when we exceed max precision
DecimalType(precision, scale)
Member: Shouldn't we also prevent scale > MAX_SCALE?

Contributor (author): That is prevented outside this function.

} else {
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
val intDigits = precision - scale
// If original scale less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
Member: Sounds like MAXIMUM_ADJUSTED_SCALE instead of MINIMUM_ADJUSTED_SCALE.

mgaido91 (Contributor, author), Dec 21, 2017: It is the MINIMUM_ADJUSTED_SCALE: we can't get a scale lower than that, even when we would need one to avoid losing precision. Please see the comments above.

Member: "We can't have a scale lower than that..." But don't you get a scale lower than MINIMUM_ADJUSTED_SCALE from Math.min(scale, MINIMUM_ADJUSTED_SCALE)?

Contributor (author): Yes, sorry, my answer was poorly worded; let me rephrase. scale is the scale we need to represent the values without any precision loss. Here we are computing the lower bound for the result scale: either the scale needed to represent the value exactly or MINIMUM_ADJUSTED_SCALE, whichever is smaller. In the line below, the scale we actually use is the max between this minScaleValue and the digits of precision left over after the integral part: i.e., even when the digits needed to the left of the dot would force a scale lower than MINIMUM_ADJUSTED_SCALE, we still keep at least MINIMUM_ADJUSTED_SCALE. We don't let the scale drop below this threshold, even though doing so would be needed to guarantee that we don't lose digits to the left of the dot. Please refer to the blog post linked in the comment above for a further (hopefully better) explanation.
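A worked REPL sketch of that two-step clamp, mirroring the lines above in plain Scala (adjustedScaleOf is an illustrative name):

val MAX_PRECISION = 38
val MINIMUM_ADJUSTED_SCALE = 6

def adjustedScaleOf(precision: Int, scale: Int): Int = {
  val intDigits = precision - scale
  val minScaleValue = math.min(scale, MINIMUM_ADJUSTED_SCALE)
  math.max(MAX_PRECISION - intDigits, minScaleValue)
}

// decimal(38,18) * decimal(38,18): ideal decimal(77,36); intDigits = 41 leaves
// 38 - 41 < 0 digits for the scale, so the 6-digit floor wins:
println(adjustedScaleOf(77, 36)) // 6  -> decimal(38,6)
// decimal(38,18) + decimal(38,18): ideal decimal(39,18); intDigits = 21 leaves room for 17:
println(adjustedScaleOf(39, 18)) // 17 -> decimal(38,17)

These match the updated expectations in the DecimalPrecisionSuite changes below.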

val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)
Member: Sounds like Math.min?

Contributor (author): It is max because we take either the scale that preserves the room needed for intDigits, i.e. the part to the left of the dot, or minScaleValue, the scale we guarantee to provide at least.

Contributor: This line needs some comments.


DecimalType(MAX_PRECISION, adjustedScale)
}
}

override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT

override private[sql] def acceptsType(other: DataType): Boolean = {
@@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11))
assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6))
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
}
@@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {

test("maximum decimals") {
for (expr <- Seq(d1, d2, i, u)) {
checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
checkType(Add(expr, u), DecimalType(38, 17))
checkType(Subtract(expr, u), DecimalType(38, 17))
}

checkType(Multiply(d1, u), DecimalType(38, 19))
checkType(Multiply(d2, u), DecimalType(38, 20))
checkType(Multiply(i, u), DecimalType(38, 18))
checkType(Multiply(u, u), DecimalType(38, 36))
checkType(Multiply(d1, u), DecimalType(38, 16))
checkType(Multiply(d2, u), DecimalType(38, 14))
checkType(Multiply(i, u), DecimalType(38, 7))
checkType(Multiply(u, u), DecimalType(38, 6))

checkType(Divide(u, d1), DecimalType(38, 18))
checkType(Divide(u, d2), DecimalType(38, 19))
checkType(Divide(u, i), DecimalType(38, 23))
checkType(Divide(u, u), DecimalType(38, 18))
checkType(Divide(u, d1), DecimalType(38, 17))
checkType(Divide(u, d2), DecimalType(38, 16))
checkType(Divide(u, i), DecimalType(38, 18))
checkType(Divide(u, u), DecimalType(38, 6))

checkType(Remainder(d1, u), DecimalType(19, 18))
checkType(Remainder(d2, u), DecimalType(21, 18))
16 changes: 16 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/decimals.sql
@@ -0,0 +1,16 @@
-- tests for decimals handling in operations
Contributor: Why create a new test file instead of adding more cases to decimalArithmeticOperations.sql?

Contributor (author): Because that file was meant for the typeCoercion modes (e.g. if we introduce a sql2011 mode which throws an exception instead of returning NULL), while this one is more generally about the behavior of arithmetic operations.

Contributor: That file is under .../typeCoercion/native/, which is meant for the default behavior (the native mode of type coercion). If we introduce a sql2011 mode, we will put the same file under .../typeCoercion/sql2011/.

Contributor (author): I see. I'll merge this file into the other one then. Thanks.

-- Spark draws its inspiration from the Hive implementation
dongjoon-hyun (Member), Dec 20, 2017: The hyperlinks in the PR came from Microsoft, and the primary purpose is SQL compliance. Can we remove this line?

create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet;

insert into decimals_test values(1, 100.0, 999.0);
insert into decimals_test values(2, 12345.123, 12345.123);
insert into decimals_test values(3, 0.1234567891011, 1234.1);
insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789);
Member: nit. How about making it a single SQL statement?

insert into decimals_test values (1, 100.0, 999.0), (2, 12345.123, 12345.123), (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789)


-- test decimal operations
select id, a+b, a-b, a*b, a/b from decimals_test order by id;

-- test operations between decimals and constants
select id, a*10, b/10 from decimals_test order by id;

drop table decimals_test;
72 changes: 72 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/decimals.sql.out
@@ -0,0 +1,72 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 8


-- !query 0
create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
insert into decimals_test values(1, 100.0, 999.0)
-- !query 1 schema
struct<>
-- !query 1 output



-- !query 2
insert into decimals_test values(2, 12345.123, 12345.123)
-- !query 2 schema
struct<>
-- !query 2 output



-- !query 3
insert into decimals_test values(3, 0.1234567891011, 1234.1)
-- !query 3 schema
struct<>
-- !query 3 output



-- !query 4
insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789)
-- !query 4 schema
struct<>
-- !query 4 output



-- !query 5
select id, a+b, a-b, a*b, a/b from decimals_test order by id
-- !query 5 schema
struct<id:int,(a + b):decimal(38,17),(a - b):decimal(38,17),(a * b):decimal(38,6),(a / b):decimal(38,6)>
-- !query 5 output
1 1099 -899 99900 0.1001
2 24690.246 0 152402061.885129 1
3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001
4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109


-- !query 6
select id, a*10, b/10 from decimals_test order by id
-- !query 6 schema
struct<id:int,(CAST(a AS DECIMAL(38,18)) * CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,15),(CAST(b AS DECIMAL(38,18)) / CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,18)>
-- !query 6 output
1 1000 99.9
2 123451.23 1234.5123
3 1.234567891011 123.41
4 1234567891234567890 0.112345678912345679


-- !query 7
drop table decimals_test
-- !query 7 schema
struct<>
-- !query 7 output

@@ -1526,15 +1526,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
Row(null))
Member: Two cases (2 and 3) were mentioned in the email. If this is the only NULL-returning test case from the previous behavior, can we have another test case?

Currently, Spark behaves as follows:

   1. It follows some rules taken from the initial Hive implementation;
   2. it returns NULL;
   3. it returns NULL.

Contributor (author): The third case, going outside the representable range of values, is never exercised in the current codebase. I haven't added a test for it because I was waiting for feedback from the community about how to handle that case, and I focused this PR only on points 1 and 2. But I can add a test case for it and eventually change it in a future PR to address the 3rd point in the e-mail. Thanks.

Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))

checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
Row(BigDecimal("3.4333333333333333333", new MathContext(38))))
checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
Row(BigDecimal("3.4333333333333333333", new MathContext(38))))
}

test("SPARK-10215 Div of Decimal returns null") {