From d7a35358e2068eca9bdead2b93f3b96dcaf890d8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 23:02:13 -0700 Subject: [PATCH 1/5] [SPARK-9303] Decimal should use java.math.Decimal directly instead of via Scala wrapper --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 50 ++++++++++--------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d52c50..d4e319845bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -192,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 1000000L).longValue() + d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue() } private[this] def doubleToTimestamp(d: Double): Any = { if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc29..3e99d2999ca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{BigDecimal => JavaBigDecimal} + import org.apache.spark.annotation.DeveloperApi /** @@ -30,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE} - private var decimalVal: BigDecimal = null + private var decimalVal: JavaBigDecimal = null private var longVal: Long = 0L private var _precision: Int = 1 private var _scale: Int = 0 @@ -44,7 +46,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(longVal: Long): Decimal = { if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { // We can't represent this compactly as a long without risking overflow - this.decimalVal = BigDecimal(longVal) + this.decimalVal = new JavaBigDecimal(longVal) this.longVal = 0L } else { this.decimalVal = null @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(unscaled) + this.decimalVal = new JavaBigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -105,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying() require(decimalVal.precision <= precision, "Overflowed precision") this.longVal = 0L this._precision = precision @@ -117,7 +119,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. */ def set(decimal: BigDecimal): Decimal = { - this.decimalVal = decimal + this.decimalVal = decimal.underlying() this.longVal = 0L this._precision = decimal.precision this._scale = decimal.scale @@ -135,19 +137,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = { + def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal) + + def toJavaBigDecimal: JavaBigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { - BigDecimal(longVal, _scale) + JavaBigDecimal.valueOf(longVal, _scale) } } - def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() - def toUnscaledLong: Long = { if (decimalVal.ne(null)) { - decimalVal.underlying().unscaledValue().longValue() + decimalVal.unscaledValue().longValue() } else { longVal } @@ -164,9 +166,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toDouble: Double = toBigDecimal.doubleValue() + def toDouble: Double = toJavaBigDecimal.doubleValue() - def toFloat: Float = toBigDecimal.floatValue() + def toFloat: Float = toJavaBigDecimal.floatValue() def toLong: Long = { if (decimalVal.eq(null)) { @@ -208,7 +210,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { longVal *= POW_10(diff) } else { // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = BigDecimal(longVal, _scale) + decimalVal = JavaBigDecimal.valueOf(longVal, _scale) } } // In both cases, we will check whether our precision is okay below @@ -217,7 +219,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id) if (newVal.precision > precision) { return false } @@ -242,7 +244,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 } else { - toBigDecimal.compare(other.toBigDecimal) + toJavaBigDecimal.compareTo(other.toJavaBigDecimal) } } @@ -253,27 +255,27 @@ final class Decimal extends Ordered[Decimal] with Serializable { false } - override def hashCode(): Int = toBigDecimal.hashCode() + override def hashCode(): Int = toJavaBigDecimal.hashCode() def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 - def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + def + (that: Decimal): Decimal = Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal)) - def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + def - (that: Decimal): Decimal = Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal)) - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal)) def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { if (decimalVal.ne(null)) { - Decimal(-decimalVal) + Decimal(decimalVal.negate()) } else { Decimal(-longVal, precision, scale) } @@ -290,7 +292,7 @@ object Decimal { private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO = BigDecimal(0) + private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) def apply(value: Double): Decimal = new Decimal().set(value) @@ -300,7 +302,7 @@ object Decimal { def apply(value: BigDecimal): Decimal = new Decimal().set(value) - def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: JavaBigDecimal): Decimal = new Decimal().set(value) def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) From 8eee859d1d55b8a44fb6612a36e90d9487368b62 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 20:42:59 -0700 Subject: [PATCH 2/5] clean up --- .../org/apache/spark/sql/types/Decimal.scala | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 4176dabd2180..870c71cb6d18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.math.{BigDecimal => JavaBigDecimal, RoundingMode, MathContext} +import java.math.{MathContext, RoundingMode, BigDecimal => JavaBigDecimal} import org.apache.spark.annotation.DeveloperApi @@ -107,7 +107,21 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying() + set(decimal.underlying(), precision, scale) + } + + /** + * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + */ + def set(decimal: BigDecimal): Decimal = { + set(decimal.underlying()) + } + + /** + * Set this Decimal to the given java.math.BigDecimal value, with a given precision and scale. + */ + private[sql] def set(decimal: JavaBigDecimal, precision: Int, scale: Int): Decimal = { + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) require(decimalVal.precision <= precision, "Overflowed precision") this.longVal = 0L this._precision = precision @@ -116,10 +130,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + * Set this Decimal to the given java.math.BigDecimal value, inheriting its precision and scale. */ - def set(decimal: BigDecimal): Decimal = { - this.decimalVal = decimal.underlying() + private[sql] def set(decimal: JavaBigDecimal): Decimal = { + this.decimalVal = decimal this.longVal = 0L this._precision = decimal.precision this._scale = decimal.scale @@ -262,7 +276,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def hashCode(): Int = toBigDecimal.hashCode() def isZero: Boolean = { - if (decimalVal.ne(null)) decimalVal.compareTo(BIG_DEC_ZERO) == 0 else longVal == 0 + if (decimalVal.ne(null)) decimalVal.compareTo(BIG_DEC_ZERO) == 0 + else longVal == 0 } def + (that: Decimal): Decimal = { @@ -291,8 +306,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) } - def % (that: Decimal): Decimal = + def % (that: Decimal): Decimal = { if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal)) + } def remainder(that: Decimal): Decimal = this % that @@ -309,15 +325,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { object Decimal { private val ROUNDING_MODE = RoundingMode.HALF_UP - - /** Maximum number of decimal digits a Long can represent */ - val MAX_LONG_DIGITS = 18 - + private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, ROUNDING_MODE) private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, ROUNDING_MODE) + /** Maximum number of decimal digits a Long can represent */ + val MAX_LONG_DIGITS = 18 val ZERO = Decimal(0) val ONE = Decimal(1) @@ -335,7 +348,7 @@ object Decimal { def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) - def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal = + def apply(value: JavaBigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) def apply(unscaled: Long, precision: Int, scale: Int): Decimal = From 7b70c28a0043957f3ed25eaa62f14a085a45ae05 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 21:04:57 -0700 Subject: [PATCH 3/5] fix build --- .../scala/org/apache/spark/sql/types/Decimal.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 870c71cb6d18..e63e6857ab5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -237,7 +237,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id) + val newVal = decimalVal.setScale(scale, ROUNDING_MODE) if (newVal.precision > precision) { return false } @@ -324,17 +324,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { } object Decimal { - private val ROUNDING_MODE = RoundingMode.HALF_UP - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, ROUNDING_MODE) - private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) - /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 val ZERO = Decimal(0) val ONE = Decimal(1) + private val ROUNDING_MODE = RoundingMode.HALF_UP + private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, ROUNDING_MODE) + private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) + def apply(value: Double): Decimal = new Decimal().set(value) def apply(value: Long): Decimal = new Decimal().set(value) From c8610016a1581211d4a70c6b66be60df7104ac45 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 22:52:50 -0700 Subject: [PATCH 4/5] fix test --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ddf126ead993..b4af1aa05e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -450,7 +450,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val decimalAdd = "$plus" s""" ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); - if (r.compare(Decimal.ZERO) < 0) { + if (r.compare(Decimal.ZERO()) < 0) { ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); } else { ${ev.primitive} = r; From 56190ef6331e005c8b9c7355031703e6a5f28ddf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 8 Aug 2015 00:13:46 -0700 Subject: [PATCH 5/5] simplify Decimal --- .../spark/sql/catalyst/expressions/Cast.scala | 6 +- .../org/apache/spark/sql/types/Decimal.scala | 174 ++++-------------- .../sql/types/decimal/DecimalSuite.scala | 52 ++---- .../spark/unsafe/PlatformDependent.java | 8 + 4 files changed, 57 insertions(+), 183 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 261df419b12c..474b057c35ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -562,16 +562,16 @@ case class Cast(child: Expression, dataType: DataType) java.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case DecimalType() => + case dt: DecimalType => (c, evPrim, evNull) => s""" Decimal tmpDecimal = $c.clone(); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case LongType => + case ByteType | ShortType | IntegerType | LongType => (c, evPrim, evNull) => s""" - Decimal tmpDecimal = Decimal.apply($c); + Decimal tmpDecimal = Decimal.apply((long) $c); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case x: NumericType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e63e6857ab5f..8d831c5920c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.math.{MathContext, RoundingMode, BigDecimal => JavaBigDecimal} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.unsafe.PlatformDependent /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -32,28 +33,18 @@ import org.apache.spark.annotation.DeveloperApi final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ - private var decimalVal: JavaBigDecimal = null - private var longVal: Long = 0L + private var decimalVal: JavaBigDecimal = BIG_DEC_ZERO private var _precision: Int = 1 - private var _scale: Int = 0 def precision: Int = _precision - def scale: Int = _scale + def scale: Int = decimalVal.scale() /** * Set this Decimal to the given Long. Will have precision 20 and scale 0. */ def set(longVal: Long): Decimal = { - if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { - // We can't represent this compactly as a long without risking overflow - this.decimalVal = new JavaBigDecimal(longVal) - this.longVal = 0L - } else { - this.decimalVal = null - this.longVal = longVal - } - this._precision = 20 - this._scale = 0 + decimalVal = JavaBigDecimal.valueOf(longVal) + _precision = 20 this } @@ -61,45 +52,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given Int. Will have precision 10 and scale 0. */ def set(intVal: Int): Decimal = { - this.decimalVal = null - this.longVal = intVal - this._precision = 10 - this._scale = 0 + decimalVal = JavaBigDecimal.valueOf(intVal) + _precision = 10 this } /** * Set this Decimal to the given unscaled Long, with a given precision and scale. + * + * Note: this is used in serialization, caller will make sure that it will not overflow */ def set(unscaled: Long, precision: Int, scale: Int): Decimal = { - if (setOrNull(unscaled, precision, scale) == null) { - throw new IllegalArgumentException("Unscaled value too large for precision") - } - this - } - - /** - * Set this Decimal to the given unscaled Long, with a given precision and scale, - * and return it, or return null if it cannot be set due to overflow. - */ - def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { - if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { - // We can't represent this compactly as a long without risking overflow - if (precision < 19) { - return null // Requested precision is too low to represent this value - } - this.decimalVal = new JavaBigDecimal(unscaled) - this.longVal = 0L - } else { - val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) - if (unscaled <= -p || unscaled >= p) { - return null // Requested precision is too low to represent this value - } - this.decimalVal = null - this.longVal = unscaled - } - this._precision = precision - this._scale = scale + decimalVal = JavaBigDecimal.valueOf(unscaled, scale) + _precision = precision this } @@ -121,11 +86,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given java.math.BigDecimal value, with a given precision and scale. */ private[sql] def set(decimal: JavaBigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + decimalVal = decimal.setScale(scale, ROUNDING_MODE) require(decimalVal.precision <= precision, "Overflowed precision") - this.longVal = 0L - this._precision = precision - this._scale = scale + _precision = precision this } @@ -134,9 +97,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ private[sql] def set(decimal: JavaBigDecimal): Decimal = { this.decimalVal = decimal - this.longVal = 0L this._precision = decimal.precision - this._scale = decimal.scale this } @@ -145,52 +106,36 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(decimal: Decimal): Decimal = { this.decimalVal = decimal.decimalVal - this.longVal = decimal.longVal this._precision = decimal._precision - this._scale = decimal._scale this } def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal) - private[sql] def toJavaBigDecimal: JavaBigDecimal = { - if (decimalVal.ne(null)) { - decimalVal - } else { - JavaBigDecimal.valueOf(longVal, _scale) - } - } + private[sql] def toJavaBigDecimal: JavaBigDecimal = decimalVal def toUnscaledLong: Long = { - if (decimalVal.ne(null)) { - decimalVal.unscaledValue().longValue() + val unscaled = PlatformDependent.UNSAFE.getLong(decimalVal, + PlatformDependent.BIG_DECIMAL_INTCOMPACT_OFFSET) + if (unscaled != Long.MinValue) { + unscaled } else { - longVal + decimalVal.unscaledValue().longValue() } } - override def toString: String = toJavaBigDecimal.toString() + override def toString: String = decimalVal.toString() @DeveloperApi def toDebugString: String = { - if (decimalVal.ne(null)) { - s"Decimal(expanded,$decimalVal,$precision,$scale})" - } else { - s"Decimal(compact,$longVal,$precision,$scale})" - } + s"Decimal($decimalVal,${_precision})" } def toDouble: Double = toJavaBigDecimal.doubleValue() def toFloat: Float = toJavaBigDecimal.floatValue() - def toLong: Long = { - if (decimalVal.eq(null)) { - longVal / POW_10(_scale) - } else { - decimalVal.longValue() - } - } + def toLong: Long = decimalVal.longValue() def toInt: Int = toLong.toInt @@ -205,65 +150,23 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def changePrecision(precision: Int, scale: Int): Boolean = { // fast path for UnsafeProjection - if (precision == this.precision && scale == this.scale) { + if (precision == _precision && scale == decimalVal.scale()) { return true } - // First, update our longVal if we can, or transfer over to using a BigDecimal - if (decimalVal.eq(null)) { - if (scale < _scale) { - // Easier case: we just need to divide our scale down - val diff = _scale - scale - val droppedDigits = longVal % POW_10(diff) - longVal /= POW_10(diff) - if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { - longVal += (if (longVal < 0) -1L else 1L) - } - } else if (scale > _scale) { - // We might be able to multiply longVal by a power of 10 and not overflow, but if not, - // switch to using a BigDecimal - val diff = scale - _scale - val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) - if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { - // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS - longVal *= POW_10(diff) - } else { - // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = JavaBigDecimal.valueOf(longVal, _scale) - } - } - // In both cases, we will check whether our precision is okay below - } - if (decimalVal.ne(null)) { - // We get here if either we started with a BigDecimal, or we switched to one because we would - // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) - if (newVal.precision > precision) { - return false - } - decimalVal = newVal - } else { - // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) - if (longVal <= -p || longVal >= p) { - // Note that we shouldn't have been able to fix this by switching to BigDecimal - return false - } + val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + if (newVal.precision > precision) { + return false } - + decimalVal = newVal _precision = precision - _scale = scale true } override def clone(): Decimal = new Decimal().set(this) override def compare(other: Decimal): Int = { - if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { - if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 - } else { - toJavaBigDecimal.compareTo(other.toJavaBigDecimal) - } + toJavaBigDecimal.compareTo(other.toJavaBigDecimal) } override def equals(other: Any): Boolean = other match { @@ -276,24 +179,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def hashCode(): Int = toBigDecimal.hashCode() def isZero: Boolean = { - if (decimalVal.ne(null)) decimalVal.compareTo(BIG_DEC_ZERO) == 0 - else longVal == 0 + decimalVal.compareTo(BIG_DEC_ZERO) == 0 } def + (that: Decimal): Decimal = { - if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { - Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale) - } else { - Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale) - } + Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale) } def - (that: Decimal): Decimal = { - if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { - Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale) - } else { - Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale) - } + Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale) } // HiveTypeCoercion will take care of the precision, scale of result @@ -313,11 +207,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { - if (decimalVal.ne(null)) { - Decimal(decimalVal.negate(), precision, scale) - } else { - Decimal(-longVal, precision, scale) - } + Decimal(decimalVal.negate(), precision, scale) } def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 6921d15958a5..911787af9b57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -46,11 +46,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) - intercept[IllegalArgumentException](Decimal(170L, 2, 1)) - intercept[IllegalArgumentException](Decimal(170L, 2, 0)) intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) - intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } test("creating decimals with negative scale") { @@ -88,36 +85,19 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) } - // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs - private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) - - /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ - private def checkCompact(d: Decimal, expected: Boolean): Unit = { - val isCompact = d.invokePrivate(decimalVal()).eq(null) - assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") - } - - test("small decimals represented as unscaled long") { - checkCompact(new Decimal(), true) - checkCompact(Decimal(BigDecimal(10.03)), false) - checkCompact(Decimal(BigDecimal(1e20)), false) - checkCompact(Decimal(17L), true) - checkCompact(Decimal(17), true) - checkCompact(Decimal(17L, 2, 1), true) - checkCompact(Decimal(170L, 4, 2), true) - checkCompact(Decimal(17L, 24, 1), true) - checkCompact(Decimal(1e16.toLong), true) - checkCompact(Decimal(1e17.toLong), true) - checkCompact(Decimal(1e18.toLong - 1), true) - checkCompact(Decimal(- 1e18.toLong + 1), true) - checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) - checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) - checkCompact(Decimal(1e18.toLong), false) - checkCompact(Decimal(-1e18.toLong), false) - checkCompact(Decimal(1e18.toLong, 30, 10), false) - checkCompact(Decimal(-1e18.toLong, 30, 10), false) - checkCompact(Decimal(Long.MaxValue), false) - checkCompact(Decimal(Long.MinValue), false) + test("change precision and scale") { + assert(true === Decimal(5).changePrecision(1, 0)) + assert(false === Decimal(15).changePrecision(1, 0)) + assert(true === Decimal(5).changePrecision(2, 1)) + assert(false === Decimal(5).changePrecision(2, 2)) + assert(true === Decimal(0).changePrecision(1, 0)) + assert(true === Decimal(BigDecimal("10.5")).changePrecision(3, 0)) + assert(true === Decimal(BigDecimal("10.5")).changePrecision(3, 1)) + assert(false === Decimal(BigDecimal("10.5")).changePrecision(3, 2)) + assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 0)) + assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 1)) + assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 2)) + assert(false === Decimal(BigDecimal("10.5")).changePrecision(4, 3)) } test("hash code") { @@ -132,10 +112,6 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } test("equals") { - // The decimals on the left are stored compactly, while the ones on the right aren't - checkCompact(Decimal(123), true) - checkCompact(Decimal(BigDecimal(123)), false) - checkCompact(Decimal("123"), false) assert(Decimal(123) === Decimal(BigDecimal(123))) assert(Decimal(123) === Decimal(BigDecimal("123.00"))) assert(Decimal(-123) === Decimal(BigDecimal(-123))) @@ -187,7 +163,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(b.toDouble === 0.125) } - test("set/setOrNull") { + test("set") { assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L) assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index b2de2a2590f0..e193436d9a67 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe; import java.lang.reflect.Field; +import java.math.BigDecimal; import java.math.BigInteger; import sun.misc.Unsafe; @@ -119,6 +120,7 @@ public static void freeMemory(long address) { // Support for resetting final fields while deserializing public static final long BIG_INTEGER_SIGNUM_OFFSET; public static final long BIG_INTEGER_MAG_OFFSET; + public static final long BIG_DECIMAL_INTCOMPACT_OFFSET; /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to @@ -145,21 +147,27 @@ public static void freeMemory(long address) { long signumOffset = 0; long magOffset = 0; + long intCompactOffset = 0; try { signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum")); magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag")); + intCompactOffset = _UNSAFE.objectFieldOffset( + BigDecimal.class.getDeclaredField("intCompact")); } catch (Exception ex) { // should not happen } BIG_INTEGER_SIGNUM_OFFSET = signumOffset; BIG_INTEGER_MAG_OFFSET = magOffset; + BIG_DECIMAL_INTCOMPACT_OFFSET = intCompactOffset; } else { + // should not happen BYTE_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; BIG_INTEGER_SIGNUM_OFFSET = 0; BIG_INTEGER_MAG_OFFSET = 0; + BIG_DECIMAL_INTCOMPACT_OFFSET = 0; } }